readme updates

This commit is contained in:
Jai Suphavadeeprasit 2025-12-10 15:50:48 -05:00
parent 88ccaa0ea5
commit d6f389f86f

View file

@ -46,27 +46,90 @@ pip install datasets latex2sympy2_extended math_verify
└─────────────────────────────────────────────────────────────────┘
```
### IMPORTANT: Startup Order (Same for ALL Modes!)
All three training modes use the **same startup order**:
```
1. Atropos API (python -m atroposlib.cli.run_api)
2. vLLM Server (python example_trainer/vllm_api_server.py)
↓ wait 60-90s for model load
3. GRPO Trainer (python example_trainer/grpo.py)
4. GSM8k Environment (python environments/gsm8k_server.py serve ...)
```
**Why this order?**
- The API must be running before the trainer or environment tries to connect
- vLLM must be loaded before GSM8k tries to generate rollouts
- The trainer must be running before GSM8k sends scored batches
- GSM8k is started last because it immediately begins generating work
### GSM8k CLI Arguments (Required for All Modes)
When starting the GSM8k environment, always include these arguments:
```bash
python environments/gsm8k_server.py serve \
--slurm False \
--openai.model_name Qwen/Qwen2.5-3B-Instruct \
--openai.base_url http://localhost:9001/v1 \
--openai.server_type vllm \
--env.tokenizer_name Qwen/Qwen2.5-3B-Instruct
```
| Argument | Required | Description |
|----------|----------|-------------|
| `--slurm False` | Yes | Disable SLURM mode for local runs |
| `--openai.model_name` | Yes | Model name (must match vLLM) |
| `--openai.base_url` | Yes | vLLM server URL with `/v1` suffix |
| `--openai.server_type vllm` | **Yes** | Must be `vllm` for `/generate` endpoint |
| `--env.tokenizer_name` | Yes | Tokenizer for environment |
**Note:** `--openai.server_type vllm` is required because only the `VLLMServer` class supports `tokens_and_logprobs_completion` which GSM8k needs.
---
## Mode 1: Legacy (Checkpoint + Restart)
This is the simplest mode. The trainer periodically saves checkpoints and restarts vLLM.
This mode saves checkpoints periodically and can restart vLLM with updated weights.
### Startup Order (Same for All Modes!)
```
┌────────────────────────────────────────────────────────────────┐
│ 1. Atropos API → Coordinates environments + trainer │
│ 2. vLLM Server → Serves inference requests │
│ 3. GRPO Trainer → Trains model, fetches batches │
│ 4. GSM8k Environment → Generates problems, scores rollouts │
└────────────────────────────────────────────────────────────────┘
```
### Step-by-Step Guide
**Terminal 1: Start the Atropos API**
**Step 1: Start the Atropos API**
```bash
cd atropos
run-api
python -m atroposlib.cli.run_api > api.log 2>&1 &
sleep 5
```
**Terminal 2: Start the GSM8k Environment**
**Step 2: Start the vLLM Server**
```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
python example_trainer/vllm_api_server.py \
--model Qwen/Qwen2.5-3B-Instruct \
--port 9001 \
--gpu-memory-utilization 0.30 \
> vllm.log 2>&1 &
sleep 90 # Wait for model to load
# Verify vLLM is ready
curl -s http://localhost:9001/health && echo "vLLM ready!"
```
**Terminal 3: Start the GRPO Trainer**
**Step 3: Start the GRPO Trainer**
```bash
cd atropos
python example_trainer/grpo.py \
@ -74,86 +137,146 @@ python example_trainer/grpo.py \
--weight-bridge-mode none \
--training-steps 100 \
--vllm-restart-interval 10 \
--vllm-port 9001 \
--vllm-gpu-memory-utilization 0.30 \
--batch-size 2 \
--gradient-accumulation-steps 16 \
--lr 1e-5 \
--save-path checkpoints_legacy \
--use-wandb \
--wandb-project gsm8k-grpo
--wandb-project gsm8k-grpo \
> trainer.log 2>&1 &
sleep 10
```
**Step 4: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve \
--slurm False \
--openai.model_name Qwen/Qwen2.5-3B-Instruct \
--openai.base_url http://localhost:9001/v1 \
--openai.server_type vllm \
--env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
> gsm8k.log 2>&1 &
```
**Monitor Training:**
```bash
tail -f trainer.log
```
### Quick Copy-Paste (All-in-One)
```bash
cd atropos && \
pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \
python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \
python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \
python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode none --training-steps 100 --vllm-port 9001 --vllm-gpu-memory-utilization 0.30 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-5 --save-path checkpoints_legacy --use-wandb --wandb-project gsm8k-grpo > trainer.log 2>&1 & sleep 10 && \
python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \
tail -f trainer.log
```
### What Happens
1. Trainer loads `Qwen/Qwen2.5-3B-Instruct` into GPU memory
2. Trainer launches vLLM server on port 9001
1. vLLM server starts and loads `Qwen/Qwen2.5-3B-Instruct`
2. Trainer loads its own copy of the model for training
3. GSM8k env sends problems → vLLM generates solutions → scores sent to API
4. Trainer fetches scored batches from API, computes GRPO loss, updates weights
5. Every 10 steps: save checkpoint → kill vLLM → restart vLLM with new weights
5. Every N steps: save checkpoint (weights stay in sync via external vLLM)
6. Repeat until done
### Pros & Cons
+ Simple, works out of the box
+ Simple conceptually
+ Easy to debug
- 30-60 second sync latency per restart
- 2x GPU memory (trainer + vLLM both load model)
+ Uses custom vLLM server with full endpoint support
- 2x GPU memory (trainer + vLLM both load model)
- Requires external vLLM to be running
---
## Mode 2: Shared vLLM Bridge (In-Place Updates)
This mode shares GPU tensors between trainer and vLLM. Updates happen instantly.
This mode uses an HTTP-based notification system. The trainer notifies vLLM after weight updates.
### Step-by-Step Guide
**Terminal 1: Start the Atropos API**
**Step 1: Start the Atropos API**
```bash
cd atropos
run-api
python -m atroposlib.cli.run_api > api.log 2>&1 &
sleep 5
```
**Terminal 2: Set up environment variables and start vLLM with bridge support**
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0 # Single-node local mode
export MASTER_ADDR=localhost
export MASTER_PORT=26756
mkdir -p $LOGDIR
# Start the custom vLLM server with bridge endpoints
python example_trainer/vllm_api_server.py \
--model Qwen/Qwen2.5-3B-Instruct \
--port 9001 \
--gpu-memory-utilization 0.45
```
**Terminal 3: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
```
**Terminal 4: Start the GRPO Trainer in shared mode**
**Step 2: Start the vLLM Server with Bridge Support**
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0
mkdir -p $LOGDIR
python example_trainer/vllm_api_server.py \
--model Qwen/Qwen2.5-3B-Instruct \
--port 9001 \
--gpu-memory-utilization 0.30 \
> vllm.log 2>&1 &
sleep 90
curl -s http://localhost:9001/health && echo "vLLM ready!"
```
**Step 3: Start the GRPO Trainer in Shared Mode**
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0
export MASTER_ADDR=localhost
export MASTER_PORT=26756
python example_trainer/grpo.py \
--model-name Qwen/Qwen2.5-3B-Instruct \
--weight-bridge-mode shared_vllm \
--trainer-rank 0 \
--world-size 1 \
--num-inference-nodes 0 \
--training-steps 100 \
--vllm-port 9001 \
--batch-size 2 \
--gradient-accumulation-steps 16 \
--lr 1e-5 \
--save-path checkpoints_shared \
--use-wandb \
--wandb-project gsm8k-grpo-shared
--wandb-project gsm8k-grpo-shared \
> trainer.log 2>&1 &
sleep 10
```
**Step 4: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve \
--slurm False \
--openai.model_name Qwen/Qwen2.5-3B-Instruct \
--openai.base_url http://localhost:9001/v1 \
--openai.server_type vllm \
--env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
> gsm8k.log 2>&1 &
```
**Monitor Training:**
```bash
tail -f trainer.log
```
### Quick Copy-Paste (All-in-One)
```bash
cd atropos && \
pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \
export LOGDIR=/tmp/atropos_bridge && export NUM_INFERENCE_NODES=0 && mkdir -p $LOGDIR && \
python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \
python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \
python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode shared_vllm --num-inference-nodes 0 --training-steps 100 --vllm-port 9001 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-5 --save-path checkpoints_shared --use-wandb --wandb-project gsm8k-grpo-shared > trainer.log 2>&1 & sleep 10 && \
python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \
tail -f trainer.log
```
### What Happens (Local Mode - num_inference_nodes=0)
@ -203,19 +326,27 @@ This mode trains only LoRA adapter weights. Much smaller checkpoints, faster ite
### Step-by-Step Guide
**Terminal 1: Start the Atropos API**
**Step 1: Start the Atropos API**
```bash
cd atropos
run-api
python -m atroposlib.cli.run_api > api.log 2>&1 &
sleep 5
```
**Terminal 2: Start the GSM8k Environment**
**Step 2: Start the vLLM Server (Required for LoRA Hot-Swap)**
```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
python example_trainer/vllm_api_server.py \
--model Qwen/Qwen2.5-3B-Instruct \
--port 9001 \
--gpu-memory-utilization 0.30 \
> vllm.log 2>&1 &
sleep 90
curl -s http://localhost:9001/health && echo "vLLM ready!"
```
**Terminal 3: Start the GRPO Trainer in LoRA mode**
**Step 3: Start the GRPO Trainer in LoRA Mode**
```bash
cd atropos
python example_trainer/grpo.py \
@ -226,12 +357,44 @@ python example_trainer/grpo.py \
--lora-dropout 0.05 \
--lora-target-modules q_proj v_proj \
--training-steps 100 \
--vllm-restart-interval 20 \
--vllm-port 9001 \
--batch-size 2 \
--gradient-accumulation-steps 16 \
--lr 1e-4 \
--save-path checkpoints_lora \
--use-wandb \
--wandb-project gsm8k-grpo-lora
--wandb-project gsm8k-grpo-lora \
> trainer.log 2>&1 &
sleep 10
```
**Step 4: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve \
--slurm False \
--openai.model_name Qwen/Qwen2.5-3B-Instruct \
--openai.base_url http://localhost:9001/v1 \
--openai.server_type vllm \
--env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
> gsm8k.log 2>&1 &
```
**Monitor Training:**
```bash
tail -f trainer.log
```
### Quick Copy-Paste (All-in-One)
```bash
cd atropos && \
pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \
python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \
python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \
python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32 --lora-dropout 0.05 --training-steps 100 --vllm-port 9001 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-4 --save-path checkpoints_lora --use-wandb --wandb-project gsm8k-grpo-lora > trainer.log 2>&1 & sleep 10 && \
python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \
tail -f trainer.log
```
### What Happens
@ -322,6 +485,37 @@ python example_trainer/grpo.py --help
---
## Shutdown / Cleanup
### Stop All Processes
```bash
# Graceful shutdown
pkill -f "gsm8k_server"
sleep 2
pkill -f "grpo.py"
sleep 2
pkill -f "vllm_api_server"
sleep 2
pkill -f "run_api"
echo "All processes stopped"
```
### Check Running Processes
```bash
ps aux | grep -E "(grpo|vllm|gsm8k|run_api)" | grep -v grep
```
### Check GPU Usage
```bash
nvidia-smi
```
---
## Troubleshooting
### "CUDA out of memory"
@ -330,7 +524,8 @@ Try reducing:
```bash
--batch-size 1 \
--gradient-accumulation-steps 64 \
--seq-len 1024
--seq-len 1024 \
--vllm-gpu-memory-utilization 0.25
```
Or use LoRA mode which uses less memory.
@ -339,7 +534,7 @@ Or use LoRA mode which uses less memory.
Make sure the API is running:
```bash
run-api # In a separate terminal
python -m atroposlib.cli.run_api
```
### vLLM fails to start
@ -351,9 +546,38 @@ lsof -i :9001
Kill existing processes or use a different port:
```bash
pkill -f "vllm_api_server"
# or use different port:
--vllm-port 9002
```
### "NotImplementedError" or "404 Not Found" on `/generate`
This means you're using the wrong server type. Make sure:
1. You started `vllm_api_server.py` (NOT standard `vllm serve`)
2. GSM8k uses `--openai.server_type vllm` (NOT `openai`)
```bash
# CORRECT
python example_trainer/vllm_api_server.py --model ... --port 9001
python environments/gsm8k_server.py serve --openai.server_type vllm ...
# WRONG - standard vLLM doesn't have /generate endpoint
python -m vllm.entrypoints.openai.api_server --model ... --port 9001
```
### "Free memory on device is less than desired GPU memory utilization"
Lower the vLLM memory utilization:
```bash
python example_trainer/vllm_api_server.py \
--model Qwen/Qwen2.5-3B-Instruct \
--port 9001 \
--gpu-memory-utilization 0.25 # Lower this
```
### Bridge mode: "Parameter mapping file not found"
Ensure `$LOGDIR` is set and vLLM server is running:
@ -369,6 +593,22 @@ Install PEFT:
pip install peft
```
### No trajectories collected / Workers timing out
Check that all services are running in the correct order:
```bash
# Check processes
ps aux | grep -E "(run_api|vllm|grpo|gsm8k)" | grep -v grep
# Check vLLM health
curl http://localhost:9001/health
# Check API health
curl http://localhost:8000/health
```
If vLLM isn't ready, wait longer before starting GSM8k.
---
## Checkpoint Locations