# GRPO Example Trainer
This directory contains an example script (`grpo.py`) demonstrating how to integrate a custom training loop with the Atropos API for reinforcement learning using the GRPO (Group Relative Policy Optimization) algorithm.
## Training Modes

The trainer supports three weight synchronization modes:

| Mode | Description | Sync Latency | Best For |
|---|---|---|---|
| Legacy (`none`) | Save checkpoints, restart vLLM | ~30-60 seconds | Simple setups, debugging |
| Shared vLLM (`shared_vllm`) | Direct shared-memory updates | ~0 ms | Production, maximum throughput |
| LoRA (`lora_only`) | Train adapters, hot-swap | ~1-5 seconds | Memory-constrained, fast iteration |
## Quick Start with GSM8k

### Prerequisites

```bash
# Install dependencies
pip install -r example_trainer/requirements.txt

# Install GSM8k environment dependencies
pip install datasets latex2sympy2_extended math_verify
```
## Architecture Overview

```
┌─────────────────────────────────────────────────────────────────┐
│                        Training Setup                           │
│                                                                 │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────────────┐  │
│  │  GSM8k Env  │───▶│ Atropos API │◀───│    GRPO Trainer     │  │
│  │ (problems)  │    │ (batching)  │    │   (optimization)    │  │
│  └─────────────┘    └─────────────┘    └─────────────────────┘  │
│         │                                        │              │
│         │                                        │              │
│         ▼                                        ▼              │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                 vLLM Inference Server                   │    │
│  │            (generates rollouts for scoring)             │    │
│  └─────────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────────┘
```
## Mode 1: Legacy (Checkpoint + Restart)

This is the simplest mode. The trainer periodically saves checkpoints and restarts vLLM.
### Step-by-Step Guide

**Terminal 1: Start the Atropos API**

```bash
cd atropos
run-api
```

**Terminal 2: Start the GSM8k Environment**

```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
```

**Terminal 3: Start the GRPO Trainer**

```bash
cd atropos
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode none \
    --training-steps 100 \
    --vllm-restart-interval 10 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --use-wandb \
    --wandb-project gsm8k-grpo
```
### What Happens

1. Trainer loads `Qwen/Qwen2.5-3B-Instruct` into GPU memory
2. Trainer launches a vLLM server on port 9001
3. GSM8k env sends problems → vLLM generates solutions → scores are sent to the API
4. Trainer fetches scored batches from the API, computes the GRPO loss, and updates weights
5. Every 10 steps: save checkpoint → kill vLLM → restart vLLM with the new weights
6. Repeat until done
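The heart of step 4 is the group-relative advantage: each rollout's reward is normalized against the other rollouts for the same problem, so no learned value baseline is needed. A minimal sketch of that normalization (illustrative only, not the actual `grpo.py` code; it uses the population standard deviation for simplicity):

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against its own group's mean and (population) std.

    `rewards` holds the per-rollout scores for ONE problem; GRPO computes
    advantages relative to the group rather than a learned critic.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four rollouts for one GSM8k problem: two correct (1.0), two wrong (0.0)
advs = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
# mean = 0.5, std = 0.5 -> advantages close to [+1, -1, +1, -1]
```

Correct rollouts get positive advantages and are reinforced; incorrect ones are pushed down.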
### Pros & Cons

**Pros:**
- Simple, works out of the box
- Easy to debug

**Cons:**
- 30-60 second sync latency per restart
- 2x GPU memory (trainer and vLLM both load the model)
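The 2x memory cost is easy to estimate: in legacy mode the trainer and vLLM each hold a full copy of the weights. A back-of-the-envelope sketch (bf16 weights only; it deliberately ignores optimizer state, activations, and the KV cache, so real usage is higher):

```python
def weight_memory_gib(n_params, bytes_per_param=2):
    """Rough GiB needed for one copy of the weights (bf16 = 2 bytes/param)."""
    return n_params * bytes_per_param / 1024**3

single = weight_memory_gib(3e9)   # one 3B-parameter copy, ~5.6 GiB
legacy_total = 2 * single         # trainer copy + vLLM copy
print(f"one copy: {single:.1f} GiB, legacy mode: {legacy_total:.1f} GiB")
```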
## Mode 2: Shared vLLM Bridge (In-Place Updates)

This mode shares GPU tensors between the trainer and vLLM, so weight updates take effect instantly.
### Step-by-Step Guide

**Terminal 1: Start the Atropos API**

```bash
cd atropos
run-api
```

**Terminal 2: Set up environment variables and start vLLM with bridge support**

```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0  # Single-node local mode
export MASTER_ADDR=localhost
export MASTER_PORT=26756
mkdir -p $LOGDIR

# Start the custom vLLM server with bridge endpoints
python example_trainer/vllm_api_server.py \
    --model Qwen/Qwen2.5-3B-Instruct \
    --port 9001 \
    --gpu-memory-utilization 0.45
```

**Terminal 3: Start the GSM8k Environment**

```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
```

**Terminal 4: Start the GRPO Trainer in shared mode**

```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0
export MASTER_ADDR=localhost
export MASTER_PORT=26756
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode shared_vllm \
    --trainer-rank 0 \
    --world-size 1 \
    --num-inference-nodes 0 \
    --training-steps 100 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --use-wandb \
    --wandb-project gsm8k-grpo-shared
```
### What Happens (Local Mode, `num_inference_nodes=0`)

1. vLLM server starts on port 9001
2. Trainer initializes the bridge in local mode (HTTP-based, no NCCL)
3. Trainer loads its own model copy and trains normally
4. After each `optimizer.step()`, `bridge.notify_update()` sends an HTTP POST to vLLM
5. Periodic checkpoint saves sync the weights to disk

Much simpler than distributed mode!
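In local mode the notify step is just an HTTP request to the vLLM server. The sketch below only builds the request to show its shape; the endpoint path `/bridge/notify_update` and the payload are placeholders (the real ones live in `vllm_api_server.py` and `vllm_weight_bridge.py`):

```python
import json
import urllib.request

def build_notify_request(vllm_port, step):
    """Build (but do not send) the POST telling vLLM that new weights exist.

    Hypothetical endpoint and payload; check vllm_api_server.py for the
    actual route the bridge registers.
    """
    url = f"http://localhost:{vllm_port}/bridge/notify_update"
    body = json.dumps({"step": step}).encode()
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}
    )

req = build_notify_request(9001, step=42)
# urllib.request.urlopen(req)  # the bridge fires this after optimizer.step()
```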
### What Happens (Distributed Mode, `num_inference_nodes>0`)

1. vLLM server starts and writes a parameter mapping to `$LOGDIR/vllm_bridge_config.json`
2. Trainer reads the mapping and joins an NCCL process group with vLLM
3. Trainer's model parameters point to vLLM's GPU tensors (shared memory)
4. Training loop:
   - Forward pass uses the shared weights
   - `optimizer.step()` modifies the shared tensors in place
   - `bridge.notify_update()` broadcasts via Gloo
   - vLLM immediately uses the new weights for the next inference

No restarts needed!
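Steps 1-2 are plain file-based coordination: the trainer waits for vLLM to write the mapping file, then parses it. A minimal sketch of that handshake (the JSON schema inside `vllm_bridge_config.json` is not shown here and may differ from what the real bridge writes):

```python
import json
import os
import time

def wait_for_bridge_config(logdir, timeout_s=120.0, poll_s=0.5):
    """Poll until vLLM has written its parameter-mapping file, then parse it."""
    path = os.path.join(logdir, "vllm_bridge_config.json")
    deadline = time.time() + timeout_s
    while not os.path.exists(path):
        if time.time() > deadline:
            raise TimeoutError(f"{path} never appeared; is the vLLM server running?")
        time.sleep(poll_s)
    with open(path) as f:
        return json.load(f)
```

Usage would be something like `mapping = wait_for_bridge_config(os.environ["LOGDIR"])` before joining the process group.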
### Environment Variables

| Variable | Description | Example |
|---|---|---|
| `LOGDIR` | Directory for bridge coordination files | `/tmp/atropos_bridge` |
| `NUM_INFERENCE_NODES` | Number of vLLM nodes (0 = local) | `0` |
| `MASTER_ADDR` | Rendezvous address | `localhost` |
| `MASTER_PORT` | Rendezvous port | `26756` |
### Pros & Cons

**Pros:**
- ~0 ms sync latency (instant updates)
- 1x GPU memory (shared tensors)
- Maximum training throughput

**Cons:**
- More complex setup
- Requires a compatible vLLM version
## Mode 3: LoRA Adapters (Hot-Swap)

This mode trains only the LoRA adapter weights, giving much smaller checkpoints and faster iteration.
### Step-by-Step Guide

**Terminal 1: Start the Atropos API**

```bash
cd atropos
run-api
```

**Terminal 2: Start the GSM8k Environment**

```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
```

**Terminal 3: Start the GRPO Trainer in LoRA mode**

```bash
cd atropos
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode lora_only \
    --lora-r 16 \
    --lora-alpha 32 \
    --lora-dropout 0.05 \
    --lora-target-modules q_proj v_proj \
    --training-steps 100 \
    --vllm-restart-interval 20 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-4 \
    --use-wandb \
    --wandb-project gsm8k-grpo-lora
```
### What Happens

1. Trainer loads the base model and wraps it with LoRA adapters (PEFT)
2. Only adapter parameters are trainable (~0.1% of the total)
3. Training loop updates adapter weights only
4. Every N steps: save an adapter checkpoint (small, ~10-50 MB)
5. vLLM can hot-swap adapters via the `/lora/load` endpoint
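The small trainable fraction follows directly from the LoRA construction: each adapted `d_out x d_in` weight gains two small matrices, `B` (`d_out x r`) and `A` (`r x d_in`), and only those are trained. A quick sanity check of the arithmetic for a hypothetical 4096-dimensional projection at rank 16:

```python
def lora_param_count(d_out, d_in, r):
    """Trainable parameters LoRA adds to one d_out x d_in linear layer."""
    return d_out * r + r * d_in  # B is d_out x r, A is r x d_in

d, r = 4096, 16
full = d * d                       # frozen base weight
lora = lora_param_count(d, d, r)   # trainable adapter
print(f"adapter is {lora / full:.2%} of the base layer")
```

Per adapted layer this is under 1%, and the fraction across the whole model is smaller still, since only the targeted modules (e.g. `q_proj`, `v_proj`) get adapters.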
### LoRA Configuration

| Option | Default | Description |
|---|---|---|
| `--lora-r` | `16` | Rank of the low-rank matrices |
| `--lora-alpha` | `32` | Scaling factor (typically 2x the rank) |
| `--lora-dropout` | `0.05` | Dropout for regularization |
| `--lora-target-modules` | `q_proj v_proj` | Which layers to adapt |
### Common Target Module Combinations

```bash
# Minimal (fastest training)
--lora-target-modules q_proj v_proj

# Attention only
--lora-target-modules q_proj k_proj v_proj o_proj

# Full (most expressive)
--lora-target-modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj
```
### Pros & Cons

**Pros:**
- Much faster training (fewer parameters)
- Tiny checkpoints (~10-50 MB vs ~6 GB)
- Can hot-swap adapters without a full restart
- Lower GPU memory (base model is frozen)

**Cons:**
- Less expressive than full fine-tuning
- May need a higher learning rate
## Configuration Reference

### All CLI Options

```bash
python example_trainer/grpo.py --help
```

### Core Training Options

| Option | Default | Description |
|---|---|---|
| `--model-name` | (required) | HuggingFace model ID |
| `--lr` | `1e-5` | Learning rate |
| `--training-steps` | `10` | Total optimization steps |
| `--batch-size` | `2` | Micro-batch size |
| `--gradient-accumulation-steps` | `32` | Gradient accumulation steps |
| `--seq-len` | `2048` | Max sequence length |
| `--save-path` | `trained_model_checkpoints` | Checkpoint directory |
### vLLM Options

| Option | Default | Description |
|---|---|---|
| `--vllm-port` | `9001` | vLLM server port |
| `--vllm-restart-interval` | `3` | Steps between syncs |
### Weight Bridge Options

| Option | Default | Description |
|---|---|---|
| `--weight-bridge-mode` | `none` | `none`, `shared_vllm`, or `lora_only` |
| `--trainer-rank` | `0` | Distributed rank |
| `--world-size` | `1` | Total processes |
| `--init-method` | `env://` | PyTorch distributed init method |
| `--num-inference-nodes` | `0` | Number of vLLM nodes |
### Logging Options

| Option | Default | Description |
|---|---|---|
| `--use-wandb` | `False` | Enable W&B logging |
| `--wandb-project` | `None` | W&B project name |
| `--wandb-group` | `None` | W&B group name |
## Troubleshooting

### "CUDA out of memory"

Try reducing the per-step memory footprint:

```bash
--batch-size 1 \
--gradient-accumulation-steps 64 \
--seq-len 1024
```

Or use LoRA mode, which uses less memory.
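Note that halving the micro-batch while doubling the accumulation steps keeps the effective batch size, and therefore the gradient quality, unchanged; only peak activation memory shrinks. A quick check against the defaults:

```python
def effective_batch(micro_batch, accum_steps):
    """Samples contributing to each optimizer step."""
    return micro_batch * accum_steps

print(effective_batch(2, 32))  # default settings
print(effective_batch(1, 64))  # OOM fallback: same effective batch
```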
### "Connection refused" to Atropos API

Make sure the API is running:

```bash
run-api  # In a separate terminal
```

### vLLM fails to start

Check whether port 9001 is in use:

```bash
lsof -i :9001
```

Kill the existing processes or use a different port:

```bash
--vllm-port 9002
```

### Bridge mode: "Parameter mapping file not found"

Ensure `$LOGDIR` is set and the vLLM server is running:

```bash
export LOGDIR=/tmp/atropos_bridge
ls $LOGDIR/vllm_bridge_config.json
```

### LoRA mode: "PEFT library not available"

Install PEFT:

```bash
pip install peft
```
## Files in This Directory

| File | Description |
|---|---|
| `grpo.py` | Main trainer script with all modes |
| `vllm_api_server.py` | Custom vLLM server with bridge endpoints |
| `vllm_weight_bridge.py` | Shared-memory bridge implementation |
| `requirements.txt` | Python dependencies |
| `README.md` | This documentation |
## Example Runs

### Quick Test (Legacy Mode)

```bash
# Minimal test to verify the setup works
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-0.5B-Instruct \
    --training-steps 5 \
    --batch-size 1 \
    --gradient-accumulation-steps 4
```

### Full GSM8k Training (LoRA Mode)

```bash
# Recommended for single-GPU training
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode lora_only \
    --lora-r 32 \
    --lora-alpha 64 \
    --training-steps 500 \
    --batch-size 2 \
    --gradient-accumulation-steps 32 \
    --lr 5e-5 \
    --use-wandb \
    --wandb-project gsm8k-lora
```

### Production (Shared vLLM Mode)

```bash
# Maximum throughput setup
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode shared_vllm \
    --training-steps 1000 \
    --batch-size 4 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --use-wandb \
    --wandb-project gsm8k-shared
```