This directory contains an example script (`grpo.py`) demonstrating how to integrate a custom training loop with the Atropos API for reinforcement learning using the GRPO (Group Relative Policy Optimization) algorithm.

**Note: The example trainer does not support multimodal training out of the box. As other trainers add support for Atropos, we will list them in the main README; some of them may support multimodal RL. Please check the main repo README for updates.**

## Training Modes

This example uses `vLLM` for efficient inference during the (simulated) data generation phase and `transformers` for the training phase.

**Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training.

The trainer supports three weight synchronization modes:

| Mode | Description | Sync Latency | Best For |
|------|-------------|--------------|----------|
| **Legacy** (`none`) | Save checkpoints, restart vLLM | ~30-60 seconds | Simple setups, debugging |
| **Shared vLLM** (`shared_vllm`) | Direct shared-memory weight updates | ~0 ms | Production, maximum throughput |
| **LoRA** (`lora_only`) | Train adapters, hot-swap them | ~1-5 seconds | Memory-constrained setups, fast iteration |

### Custom vLLM Server

The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs.
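
For example, one might query the server directly like this (a sketch: `/generate` is the endpoint named above, but the request fields shown are assumptions, so check `vllm_api_server.py` for the actual schema):

```python
import requests

# Ask the custom vLLM server for completions with token-level logprobs.
# Field names below are illustrative; see vllm_api_server.py for the
# actual request schema.
resp = requests.post("http://localhost:9001/generate", json={
    "prompt": "What is 2 + 2?",
    "n": 4,              # a group of rollouts for GRPO scoring
    "max_tokens": 256,
    "temperature": 1.0,
    "logprobs": 1,       # request per-token logprobs
})
resp.raise_for_status()
print(resp.json())
```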

---

## Prerequisites

1. **Python:** Python 3.10 or higher is recommended.
2. **Atropos API Server:** The Atropos API server must be running and accessible (defaults to `http://localhost:8000` in the script).
3. **Python Packages:** You need to install the required Python libraries:
   * `torch` (with CUDA support recommended)
   * `transformers`
   * `vllm`
   * `pydantic`
   * `numpy`
   * `requests`
   * `tenacity`
   * `wandb` (optional, for logging)

## Setup

1. **Clone the Repository:** Ensure you have the repository containing this example.
2. **Install Dependencies:** `pip install -r example_trainer/requirements.txt`
3. **Ensure the Atropos API is Running:** Run `run-api` in a new terminal window.
4. **Run an environment:** `python environments/gsm8k_server.py serve --slurm False`

## Configuration

The training configuration is managed within the `grpo.py` script using the `TrainingConfig` Pydantic model (found near the top of the file).

Key parameters you might want to adjust include:

* `model_name`: The Hugging Face model identifier to use for training (e.g., `"gpt2"`, `"Qwen/Qwen2.5-1.5B-Instruct"`).
* `training_steps`: The total number of optimization steps to perform.
* `batch_size` / `gradient_accumulation_steps`: Control the effective batch size.
* `lr`: Learning rate.
* `save_path`: Directory where model checkpoints will be saved.
* `vllm_port`: The port used by the vLLM server instance launched by this script.
* `vllm_restart_interval`: How often (in steps) to save a checkpoint and restart the vLLM server with the new weights.
* `use_wandb`: Set to `True` to enable logging to Weights & Biases.
* `wandb_project`: Your W&B project name (required if `use_wandb=True`).
* `wandb_group`: Optional W&B group name.

**API Endpoints:** The script currently assumes the Atropos API is available at `http://localhost:8000/register` and `http://localhost:8000/batch`. If your API runs elsewhere, you'll need to modify the `register_trainer` and `get_batch` functions accordingly.
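
For orientation, here is a minimal sketch of what those two calls might look like with `requests`. The endpoint paths come from the paragraph above, but the payload fields are illustrative assumptions; check `register_trainer` and `get_batch` in `grpo.py` for the real schema:

```python
import requests

API_URL = "http://localhost:8000"  # adjust if your Atropos API runs elsewhere

# Register this trainer with the API (field names are assumptions).
resp = requests.post(f"{API_URL}/register", json={
    "batch_size": 64,
    "max_token_len": 2048,
    "starting_step": 0,
    "num_steps": 100,
})
resp.raise_for_status()

# Poll for a scored batch; the API may return an empty batch until
# enough rollouts have been scored by the environment.
batch = None
while batch is None:
    batch = requests.get(f"{API_URL}/batch").json().get("batch")
print(f"Fetched a batch with {len(batch)} items")
```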

## Running the Example

Once the prerequisites are met and the configuration is set:

1. Navigate to the root directory of the project in your terminal.
2. Run the script:

```bash
python example_trainer/grpo.py
```

## Output

* **Logs:** Training progress, loss, logprobs, and vLLM status will be printed to the console.
* **Checkpoints:** Model checkpoints will be saved periodically in the directory specified by `save_path` (default: `./trained_model_checkpoints`). A `final_model` directory will be created upon completion.
* **WandB:** If `use_wandb` is `True`, logs will be sent to Weights & Biases, and a link to the run page will be printed in the console.
* **`temp.json`:** Contains the raw data from the last fetched batch, useful for debugging and manual inspection (see the snippet below).
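
Since `temp.json` is just the raw batch payload, a quick way to peek at it (a sketch; the exact structure depends on the environment that produced the batch):

```python
import json

with open("temp.json") as f:
    batch = json.load(f)

# Print the first 500 characters to see the entry layout.
print(type(batch).__name__)
print(json.dumps(batch, indent=2)[:500])
```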

---

## Quick Start with GSM8k

### Prerequisites

```bash
# Install dependencies
pip install -r example_trainer/requirements.txt

# Install GSM8k environment dependencies
pip install datasets latex2sympy2_extended math_verify

# Run the trainer directly (basic test)
python example_trainer/grpo.py
```

### Architecture Overview

```
┌─────────────────────────────────────────────────────────────────┐
│                         Training Setup                          │
│                                                                 │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────────────┐  │
│  │  GSM8k Env  │───▶│ Atropos API │◀───│    GRPO Trainer     │  │
│  │  (problems) │    │  (batching) │    │   (optimization)    │  │
│  └─────────────┘    └─────────────┘    └─────────────────────┘  │
│         │                                         │             │
│         │                                         │             │
│         ▼                                         ▼             │
│   ┌─────────────────────────────────────────────────────────┐   │
│   │                  vLLM Inference Server                  │   │
│   │            (generates rollouts for scoring)             │   │
│   └─────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────┘
```

---

## Mode 1: Legacy (Checkpoint + Restart)

This is the simplest mode. The trainer periodically saves checkpoints and restarts vLLM.

### Step-by-Step Guide

**Terminal 1: Start the Atropos API**
```bash
cd atropos
run-api
```

**Terminal 2: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
```

**Terminal 3: Start the GRPO Trainer**
```bash
cd atropos
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode none \
    --training-steps 100 \
    --vllm-restart-interval 10 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --use-wandb \
    --wandb-project gsm8k-grpo
```

### What Happens

1. Trainer loads `Qwen/Qwen2.5-3B-Instruct` into GPU memory
2. Trainer launches the vLLM server on port 9001
3. GSM8k env sends problems → vLLM generates solutions → scores sent to API
4. Trainer fetches scored batches from the API, computes the GRPO loss, and updates weights (see the sketch after this list)
5. Every 10 steps: save checkpoint → kill vLLM → restart vLLM with new weights
6. Repeat until done
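
Step 4 is the core of the algorithm. Here is a minimal sketch of the GRPO computation, assuming one scalar reward per completion and `G` completions per prompt; this simplifies what `grpo.py` actually implements (no KL term or importance-ratio clipping), and the names are illustrative:

```python
import torch

def grpo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """Group-relative advantages: normalize each prompt group's rewards
    by that group's own mean and std. rewards: [num_groups, G]."""
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + 1e-4)

def grpo_loss(logprobs: torch.Tensor, advantages: torch.Tensor,
              mask: torch.Tensor) -> torch.Tensor:
    """Policy-gradient loss: per-token logprobs of the sampled
    completions, weighted by each completion's group-relative
    advantage. logprobs, mask: [batch, seq]; advantages: [batch]."""
    per_token = -logprobs * advantages.unsqueeze(1) * mask
    return per_token.sum() / mask.sum()
```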

### Pros & Cons

**Pros:**
* Simple, works out of the box
* Easy to debug

**Cons:**
* 30-60 second sync latency per restart
* 2x GPU memory (trainer and vLLM both load the model)

---

## Mode 2: Shared vLLM Bridge (In-Place Updates)

This mode shares GPU tensors between the trainer and vLLM. Updates happen instantly.

### Step-by-Step Guide

**Terminal 1: Start the Atropos API**
```bash
cd atropos
run-api
```

**Terminal 2: Set up environment variables and start vLLM with bridge support**
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0  # Single-node local mode
export MASTER_ADDR=localhost
export MASTER_PORT=26756

mkdir -p $LOGDIR

# Start the custom vLLM server with bridge endpoints
python example_trainer/vllm_api_server.py \
    --model Qwen/Qwen2.5-3B-Instruct \
    --port 9001 \
    --gpu-memory-utilization 0.45
```

**Terminal 3: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
```

**Terminal 4: Start the GRPO Trainer in shared mode**
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0
export MASTER_ADDR=localhost
export MASTER_PORT=26756

python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode shared_vllm \
    --trainer-rank 0 \
    --world-size 1 \
    --num-inference-nodes 0 \
    --training-steps 100 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --use-wandb \
    --wandb-project gsm8k-grpo-shared
```

### What Happens

1. vLLM server starts and writes its parameter mapping to `$LOGDIR/vllm_bridge_config.json`
2. Trainer reads the mapping and joins an NCCL process group with vLLM (sketched below)
3. Trainer's model parameters point to vLLM's GPU tensors (shared memory)
4. Training loop:
   - Forward pass uses the shared weights
   - `optimizer.step()` modifies the shared tensors in-place
   - `bridge.notify_update()` signals vLLM (optional coordination)
   - vLLM immediately uses the new weights for the next inference
5. No restarts needed!
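
Conceptually, the trainer-side handshake in step 2 looks something like the sketch below. This is not the actual `vllm_weight_bridge.py` code; only the `vllm_bridge_config.json` path and the environment variables come from this guide, and everything else is an illustrative assumption:

```python
import json
import os

import torch.distributed as dist

# Read the parameter mapping the vLLM server wrote at startup.
logdir = os.environ["LOGDIR"]
with open(os.path.join(logdir, "vllm_bridge_config.json")) as f:
    bridge_config = json.load(f)

# Join the rendezvous that vLLM set up; MASTER_ADDR / MASTER_PORT
# must match the values exported before launching the server.
dist.init_process_group(backend="nccl", init_method="env://",
                        rank=0, world_size=1)

# From here, the trainer rebinds its parameters to the shared GPU
# tensors described by bridge_config, so optimizer.step() mutates
# the same memory vLLM reads during inference.
```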

### Environment Variables

| Variable | Description | Example |
|----------|-------------|---------|
| `LOGDIR` | Directory for bridge coordination files | `/tmp/atropos_bridge` |
| `NUM_INFERENCE_NODES` | Number of vLLM nodes (0 = local) | `0` |
| `MASTER_ADDR` | Rendezvous address | `localhost` |
| `MASTER_PORT` | Rendezvous port | `26756` |

### Pros & Cons

**Pros:**
* ~0 ms sync latency (instant updates)
* 1x GPU memory (shared tensors)
* Maximum training throughput

**Cons:**
* More complex setup
* Requires a compatible vLLM version

---

## Mode 3: LoRA Adapters (Hot-Swap)

This mode trains only the LoRA adapter weights, giving much smaller checkpoints and faster iteration.

### Step-by-Step Guide

**Terminal 1: Start the Atropos API**
```bash
cd atropos
run-api
```

**Terminal 2: Start the GSM8k Environment**
```bash
cd atropos
python environments/gsm8k_server.py serve --slurm False
```

**Terminal 3: Start the GRPO Trainer in LoRA mode**
```bash
cd atropos
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode lora_only \
    --lora-r 16 \
    --lora-alpha 32 \
    --lora-dropout 0.05 \
    --lora-target-modules q_proj v_proj \
    --training-steps 100 \
    --vllm-restart-interval 20 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-4 \
    --use-wandb \
    --wandb-project gsm8k-grpo-lora
```

### What Happens

1. Trainer loads the base model and wraps it with LoRA adapters via PEFT (see the sketch after this list)
2. Only adapter parameters are trainable (~0.1% of the total)
3. Training loop updates the adapter weights only
4. Every N steps: save an adapter checkpoint (small, ~10-50 MB)
5. vLLM can hot-swap adapters via the `/lora/load` endpoint
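
Steps 1-2 follow the standard PEFT wrapping pattern. A minimal sketch using the CLI defaults above (illustrative; not the exact code in `grpo.py`):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

# Mirror the CLI defaults: --lora-r 16 --lora-alpha 32
# --lora-dropout 0.05 --lora-target-modules q_proj v_proj
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, config)
model.print_trainable_parameters()  # only the adapters are trainable

# Adapter-only checkpoints are small (tens of MB):
model.save_pretrained("./adapter_checkpoint")
```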

### LoRA Configuration

| Option | Default | Description |
|--------|---------|-------------|
| `--lora-r` | 16 | Rank of the low-rank matrices |
| `--lora-alpha` | 32 | Scaling factor (typically 2x the rank) |
| `--lora-dropout` | 0.05 | Dropout for regularization |
| `--lora-target-modules` | `q_proj v_proj` | Which layers to adapt |
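
For intuition, LoRA replaces each adapted weight matrix's update with a low-rank product scaled by `alpha / r` (which is why `alpha` is typically set to about twice the rank):

```latex
W' = W + \frac{\alpha}{r}\, B A, \qquad B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k},\; r \ll \min(d, k)
```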

### Common Target Module Combinations

```bash
# Minimal (fastest training)
--lora-target-modules q_proj v_proj

# Attention only
--lora-target-modules q_proj k_proj v_proj o_proj

# Full (most expressive)
--lora-target-modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj
```

### Pros & Cons

**Pros:**
* Much faster training (fewer trainable parameters)
* Tiny checkpoints (~10-50 MB vs ~6 GB)
* Adapters can be hot-swapped without a full restart
* Lower GPU memory (the base model stays frozen)

**Cons:**
* Less expressive than full fine-tuning
* May need a higher learning rate

---

## Configuration Reference

### All CLI Options

```bash
python example_trainer/grpo.py --help
```

### Core Training Options

| Option | Default | Description |
|--------|---------|-------------|
| `--model-name` | (required) | HuggingFace model ID |
| `--lr` | `1e-5` | Learning rate |
| `--training-steps` | `10` | Total optimization steps |
| `--batch-size` | `2` | Micro-batch size |
| `--gradient-accumulation-steps` | `32` | Gradient accumulation steps |
| `--seq-len` | `2048` | Max sequence length |
| `--save-path` | `trained_model_checkpoints` | Checkpoint directory |

### vLLM Options

| Option | Default | Description |
|--------|---------|-------------|
| `--vllm-port` | `9001` | vLLM server port |
| `--vllm-restart-interval` | `3` | Steps between weight syncs |

### Weight Bridge Options

| Option | Default | Description |
|--------|---------|-------------|
| `--weight-bridge-mode` | `none` | `none`, `shared_vllm`, or `lora_only` |
| `--trainer-rank` | `0` | Distributed rank |
| `--world-size` | `1` | Total processes |
| `--init-method` | `env://` | PyTorch distributed init method |
| `--num-inference-nodes` | `0` | Number of vLLM nodes |

### Logging Options

| Option | Default | Description |
|--------|---------|-------------|
| `--use-wandb` | `False` | Enable W&B logging |
| `--wandb-project` | `None` | W&B project name |
| `--wandb-group` | `None` | W&B group name |

---

## Troubleshooting

### "CUDA out of memory"

Try reducing:
```bash
--batch-size 1 \
--gradient-accumulation-steps 64 \
--seq-len 1024
```

Or use LoRA mode, which uses less memory.

### "Connection refused" to Atropos API

Make sure the API is running:
```bash
run-api  # In a separate terminal
```

### vLLM fails to start

Check whether port 9001 is in use:
```bash
lsof -i :9001
```

Kill the existing processes or use a different port:
```bash
--vllm-port 9002
```

### Bridge mode: "Parameter mapping file not found"

Ensure `$LOGDIR` is set and the vLLM server is running:
```bash
export LOGDIR=/tmp/atropos_bridge
ls $LOGDIR/vllm_bridge_config.json
```

### LoRA mode: "PEFT library not available"

Install PEFT:
```bash
pip install peft
```

---

## Files in This Directory

| File | Description |
|------|-------------|
| `grpo.py` | Main trainer script with all modes |
| `vllm_api_server.py` | Custom vLLM server with bridge endpoints |
| `vllm_weight_bridge.py` | Shared-memory bridge implementation |
| `requirements.txt` | Python dependencies |
| `README.md` | This documentation |

---

## Example Runs

### Quick Test (Legacy Mode)
```bash
# Minimal test to verify the setup works
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-0.5B-Instruct \
    --training-steps 5 \
    --batch-size 1 \
    --gradient-accumulation-steps 4
```

### Full GSM8k Training (LoRA Mode)
```bash
# Recommended for single-GPU training
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode lora_only \
    --lora-r 32 \
    --lora-alpha 64 \
    --training-steps 500 \
    --batch-size 2 \
    --gradient-accumulation-steps 32 \
    --lr 5e-5 \
    --use-wandb \
    --wandb-project gsm8k-lora
```

### Production (Shared vLLM Mode)
```bash
# Maximum throughput setup (rendezvous variables as in the
# Mode 2 guide above)
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0
export MASTER_ADDR=localhost
export MASTER_PORT=26756

python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode shared_vllm \
    --training-steps 1000 \
    --batch-size 4 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --use-wandb \
    --wandb-project gsm8k-shared
```