Mirror of https://github.com/NousResearch/atropos.git

readme updates

Commit d6f389f86f (parent 88ccaa0ea5), 1 changed file with 294 additions and 54 deletions

@@ -46,27 +46,90 @@ pip install datasets latex2sympy2_extended math_verify

```
└─────────────────────────────────────────────────────────────────┘
```

### IMPORTANT: Startup Order (Same for ALL Modes!)

All three training modes use the **same startup order**:

```
1. Atropos API        (python -m atroposlib.cli.run_api)
   ↓
2. vLLM Server        (python example_trainer/vllm_api_server.py)
   ↓  wait 60-90s for model load
3. GRPO Trainer       (python example_trainer/grpo.py)
   ↓
4. GSM8k Environment  (python environments/gsm8k_server.py serve ...)
```

**Why this order?**
- The API must be running before the trainer or environment tries to connect
- vLLM must be loaded before GSM8k tries to generate rollouts
- The trainer must be running before GSM8k sends scored batches
- GSM8k is started last because it immediately begins generating work
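
If you prefer not to rely on fixed `sleep` delays, you can poll each service's health endpoint before launching the next component. This is an optional sketch, not part of the repo; it assumes the default ports used throughout this guide (Atropos API on 8000, vLLM on 9001):

```bash
# Optional readiness helper: block until a health endpoint responds.
# Assumes the default ports used in this guide (API: 8000, vLLM: 9001).
wait_for() {
  local url="$1" name="$2"
  until curl -sf "$url" > /dev/null; do
    echo "Waiting for $name at $url ..."
    sleep 5
  done
  echo "$name is ready."
}

wait_for http://localhost:8000/health "Atropos API"   # before starting vLLM, trainer, env
wait_for http://localhost:9001/health "vLLM server"   # before starting the GSM8k environment
```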

### GSM8k CLI Arguments (Required for All Modes)

When starting the GSM8k environment, always include these arguments:

```bash
python environments/gsm8k_server.py serve \
    --slurm False \
    --openai.model_name Qwen/Qwen2.5-3B-Instruct \
    --openai.base_url http://localhost:9001/v1 \
    --openai.server_type vllm \
    --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct
```

| Argument | Required | Description |
|----------|----------|-------------|
| `--slurm False` | Yes | Disable SLURM mode for local runs |
| `--openai.model_name` | Yes | Model name (must match vLLM) |
| `--openai.base_url` | Yes | vLLM server URL with `/v1` suffix |
| `--openai.server_type vllm` | **Yes** | Must be `vllm` for the `/generate` endpoint |
| `--env.tokenizer_name` | Yes | Tokenizer for the environment |

**Note:** `--openai.server_type vllm` is required because only the `VLLMServer` class supports `tokens_and_logprobs_completion`, which GSM8k needs.

---

## Mode 1: Legacy (Checkpoint + Restart)

This mode saves checkpoints periodically and can restart vLLM with updated weights.

### Startup Order (Same for All Modes!)

```
┌────────────────────────────────────────────────────────────────┐
│ 1. Atropos API → Coordinates environments + trainer │
│ 2. vLLM Server → Serves inference requests │
│ 3. GRPO Trainer → Trains model, fetches batches │
│ 4. GSM8k Environment → Generates problems, scores rollouts │
└────────────────────────────────────────────────────────────────┘
```

### Step-by-Step Guide

**Step 1: Start the Atropos API**
```bash
cd atropos
python -m atroposlib.cli.run_api > api.log 2>&1 &
sleep 5
```

**Step 2: Start the vLLM Server**
```bash
cd atropos
python example_trainer/vllm_api_server.py \
    --model Qwen/Qwen2.5-3B-Instruct \
    --port 9001 \
    --gpu-memory-utilization 0.30 \
    > vllm.log 2>&1 &
sleep 90  # Wait for model to load

# Verify vLLM is ready
curl -s http://localhost:9001/health && echo "vLLM ready!"
```

**Step 3: Start the GRPO Trainer**
```bash
cd atropos
python example_trainer/grpo.py \
@@ -74,86 +137,146 @@ python example_trainer/grpo.py \
    --weight-bridge-mode none \
    --training-steps 100 \
    --vllm-restart-interval 10 \
    --vllm-port 9001 \
    --vllm-gpu-memory-utilization 0.30 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --save-path checkpoints_legacy \
    --use-wandb \
    --wandb-project gsm8k-grpo \
    > trainer.log 2>&1 &
sleep 10
```

**Step 4: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve \
    --slurm False \
    --openai.model_name Qwen/Qwen2.5-3B-Instruct \
    --openai.base_url http://localhost:9001/v1 \
    --openai.server_type vllm \
    --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
    > gsm8k.log 2>&1 &
```

**Monitor Training:**
```bash
tail -f trainer.log
```

### Quick Copy-Paste (All-in-One)

```bash
cd atropos && \
pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \
python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \
python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \
python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode none --training-steps 100 --vllm-port 9001 --vllm-gpu-memory-utilization 0.30 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-5 --save-path checkpoints_legacy --use-wandb --wandb-project gsm8k-grpo > trainer.log 2>&1 & sleep 10 && \
python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \
tail -f trainer.log
```

### What Happens

1. vLLM server starts and loads `Qwen/Qwen2.5-3B-Instruct`
2. Trainer loads its own copy of the model for training
3. GSM8k env sends problems → vLLM generates solutions → scores sent to API
4. Trainer fetches scored batches from API, computes GRPO loss, updates weights
5. Every N steps: save checkpoint (weights stay in sync via external vLLM)
6. Repeat until done
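
For reference, the "GRPO loss" in step 4 follows the group-relative scheme GRPO is named for: each prompt gets a group of rollouts, each rollout's reward is normalized against its group, and that normalized score is used as the advantage in a PPO-style clipped objective. This is the standard formulation from the GRPO literature, given here only as a sketch; the exact loss lives in `example_trainer/grpo.py` and may differ in details such as clipping and KL handling.

```latex
% Standard GRPO advantage for rollout i in a group of G rollouts
% sampled for the same prompt (sketch; see example_trainer/grpo.py
% for the repo's exact loss):
A_i = \frac{r_i - \mathrm{mean}(r_1,\ldots,r_G)}{\mathrm{std}(r_1,\ldots,r_G)}
```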

### Pros & Cons

+ Simple conceptually
+ Easy to debug
+ Uses custom vLLM server with full endpoint support
- 2x GPU memory (trainer + vLLM both load model)
- Requires external vLLM to be running

---

## Mode 2: Shared vLLM Bridge (In-Place Updates)

This mode uses an HTTP-based notification system. The trainer notifies vLLM after weight updates.

### Step-by-Step Guide

**Step 1: Start the Atropos API**
```bash
cd atropos
python -m atroposlib.cli.run_api > api.log 2>&1 &
sleep 5
```

**Step 2: Start the vLLM Server with Bridge Support**
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0  # Single-node local mode
mkdir -p $LOGDIR

# Start the custom vLLM server with bridge endpoints
python example_trainer/vllm_api_server.py \
    --model Qwen/Qwen2.5-3B-Instruct \
    --port 9001 \
    --gpu-memory-utilization 0.30 \
    > vllm.log 2>&1 &
sleep 90

curl -s http://localhost:9001/health && echo "vLLM ready!"
```
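
Bridge mode keeps its coordination files under `$LOGDIR` (the troubleshooting section below mentions a "parameter mapping file" there). As an optional sanity check, assuming the directory used above, confirm the directory exists and starts to fill once vLLM is up:

```bash
# Optional sanity check: the bridge log directory should exist and be
# populated after the vLLM server starts. The exact file names written
# there depend on the repo's bridge implementation.
ls -l "${LOGDIR:-/tmp/atropos_bridge}"
```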

**Step 3: Start the GRPO Trainer in Shared Mode**
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0
export MASTER_ADDR=localhost
export MASTER_PORT=26756

python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode shared_vllm \
    --trainer-rank 0 \
    --world-size 1 \
    --num-inference-nodes 0 \
    --training-steps 100 \
    --vllm-port 9001 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --save-path checkpoints_shared \
    --use-wandb \
    --wandb-project gsm8k-grpo-shared \
    > trainer.log 2>&1 &
sleep 10
```

**Step 4: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve \
    --slurm False \
    --openai.model_name Qwen/Qwen2.5-3B-Instruct \
    --openai.base_url http://localhost:9001/v1 \
    --openai.server_type vllm \
    --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
    > gsm8k.log 2>&1 &
```

**Monitor Training:**
```bash
tail -f trainer.log
```

### Quick Copy-Paste (All-in-One)

```bash
cd atropos && \
pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \
export LOGDIR=/tmp/atropos_bridge && export NUM_INFERENCE_NODES=0 && mkdir -p $LOGDIR && \
python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \
python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \
python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode shared_vllm --num-inference-nodes 0 --training-steps 100 --vllm-port 9001 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-5 --save-path checkpoints_shared --use-wandb --wandb-project gsm8k-grpo-shared > trainer.log 2>&1 & sleep 10 && \
python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \
tail -f trainer.log
```

### What Happens (Local Mode - num_inference_nodes=0)

@@ -203,19 +326,27 @@ This mode trains only LoRA adapter weights. Much smaller checkpoints, faster ite

### Step-by-Step Guide

**Step 1: Start the Atropos API**
```bash
cd atropos
python -m atroposlib.cli.run_api > api.log 2>&1 &
sleep 5
```

**Step 2: Start the vLLM Server (Required for LoRA Hot-Swap)**
```bash
cd atropos
python example_trainer/vllm_api_server.py \
    --model Qwen/Qwen2.5-3B-Instruct \
    --port 9001 \
    --gpu-memory-utilization 0.30 \
    > vllm.log 2>&1 &
sleep 90

curl -s http://localhost:9001/health && echo "vLLM ready!"
```

**Step 3: Start the GRPO Trainer in LoRA Mode**
```bash
cd atropos
python example_trainer/grpo.py \
@@ -226,12 +357,44 @@ python example_trainer/grpo.py \
    --lora-dropout 0.05 \
    --lora-target-modules q_proj v_proj \
    --training-steps 100 \
    --vllm-restart-interval 20 \
    --vllm-port 9001 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-4 \
    --save-path checkpoints_lora \
    --use-wandb \
    --wandb-project gsm8k-grpo-lora \
    > trainer.log 2>&1 &
sleep 10
```
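
Since only the LoRA adapter weights are trained in this mode, the checkpoints written under `--save-path` should stay small. An optional quick check once a checkpoint has been saved (path matches the command above; directory layout depends on the trainer):

```bash
# Optional: confirm LoRA checkpoints stay small compared to full-model saves.
du -sh checkpoints_lora/* 2>/dev/null
```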

**Step 4: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve \
    --slurm False \
    --openai.model_name Qwen/Qwen2.5-3B-Instruct \
    --openai.base_url http://localhost:9001/v1 \
    --openai.server_type vllm \
    --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
    > gsm8k.log 2>&1 &
```

**Monitor Training:**
```bash
tail -f trainer.log
```

### Quick Copy-Paste (All-in-One)

```bash
cd atropos && \
pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \
python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \
python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \
python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32 --lora-dropout 0.05 --training-steps 100 --vllm-port 9001 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-4 --save-path checkpoints_lora --use-wandb --wandb-project gsm8k-grpo-lora > trainer.log 2>&1 & sleep 10 && \
python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \
tail -f trainer.log
```

### What Happens

@@ -322,6 +485,37 @@ python example_trainer/grpo.py --help

---

## Shutdown / Cleanup

### Stop All Processes

```bash
# Graceful shutdown
pkill -f "gsm8k_server"
sleep 2
pkill -f "grpo.py"
sleep 2
pkill -f "vllm_api_server"
sleep 2
pkill -f "run_api"

echo "All processes stopped"
```

### Check Running Processes

```bash
ps aux | grep -E "(grpo|vllm|gsm8k|run_api)" | grep -v grep
```

### Check GPU Usage

```bash
nvidia-smi
```
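
Because the legacy and shared setups keep both the trainer and vLLM resident on the GPU, it can help to watch memory continuously rather than taking a single snapshot. One option, using standard `nvidia-smi` query flags with a 5-second refresh:

```bash
# Poll GPU memory every 5 seconds while the trainer and vLLM are both loaded.
nvidia-smi --query-gpu=index,name,memory.used,memory.total --format=csv -l 5
```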

---

## Troubleshooting

### "CUDA out of memory"

@@ -330,7 +524,8 @@ Try reducing:
```bash
--batch-size 1 \
--gradient-accumulation-steps 64 \
--seq-len 1024 \
--vllm-gpu-memory-utilization 0.25
```

Or use LoRA mode, which uses less memory.

@@ -339,7 +534,7 @@ Or use LoRA mode which uses less memory.

Make sure the API is running:
```bash
python -m atroposlib.cli.run_api  # In a separate terminal
```

### vLLM fails to start

@@ -351,9 +546,38 @@ lsof -i :9001

Kill existing processes or use a different port:
```bash
pkill -f "vllm_api_server"
# or use a different port:
--vllm-port 9002
```
### "NotImplementedError" or "404 Not Found" on `/generate`
|
||||
|
||||
This means you're using the wrong server type. Make sure:
|
||||
|
||||
1. You started `vllm_api_server.py` (NOT standard `vllm serve`)
|
||||
2. GSM8k uses `--openai.server_type vllm` (NOT `openai`)
|
||||
|
||||
```bash
|
||||
# CORRECT
|
||||
python example_trainer/vllm_api_server.py --model ... --port 9001
|
||||
python environments/gsm8k_server.py serve --openai.server_type vllm ...
|
||||
|
||||
# WRONG - standard vLLM doesn't have /generate endpoint
|
||||
python -m vllm.entrypoints.openai.api_server --model ... --port 9001
|
||||
```
|
||||
|
||||
### "Free memory on device is less than desired GPU memory utilization"
|
||||
|
||||
Lower the vLLM memory utilization:
|
||||
|
||||
```bash
|
||||
python example_trainer/vllm_api_server.py \
|
||||
--model Qwen/Qwen2.5-3B-Instruct \
|
||||
--port 9001 \
|
||||
--gpu-memory-utilization 0.25 # Lower this
|
||||
```
|
||||
|
||||

### Bridge mode: "Parameter mapping file not found"

Ensure `$LOGDIR` is set and the vLLM server is running:

@@ -369,6 +593,22 @@ Install PEFT:

```bash
pip install peft
```

### No trajectories collected / Workers timing out

Check that all services are running in the correct order:
```bash
# Check processes
ps aux | grep -E "(run_api|vllm|grpo|gsm8k)" | grep -v grep

# Check vLLM health
curl http://localhost:9001/health

# Check API health
curl http://localhost:8000/health
```

If vLLM isn't ready, wait longer before starting GSM8k.
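
When something in this chain fails, it usually shows up in whichever log stops moving first. Watching all four services side by side can narrow it down; the file names below match the redirections used in the quick-start commands in this guide:

```bash
# Follow all four service logs at once (names match the redirections above).
tail -n 50 -f api.log vllm.log trainer.log gsm8k.log
```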

---

## Checkpoint Locations