README updates

This commit is contained in:
Jai Suphavadeeprasit 2026-01-19 21:34:15 -05:00
parent 5c0ac94095
commit ee23761709


@@ -188,7 +188,7 @@ You should see:
[vLLM Patch] ✓ Exported 339 params to vllm_bridge_config.json
```
#### Step 7: Start an Environment (GSM8K in this case)
```bash
python environments/gsm8k_server.py serve \
@@ -404,6 +404,8 @@ CUDA_VISIBLE_DEVICES=0 LOGDIR=. python -u example_trainer/grpo.py \
**The Simple Approach**: Save full checkpoints, restart vLLM to load new weights.
> **Note**: In Legacy mode, the **trainer manages its own vLLM process**. Do NOT start vLLM separately - the trainer will automatically start, stop, and restart vLLM with updated checkpoints.
```
┌─────────────────────────────────────────────────────────────────────────────────────┐
│ LEGACY MODE - COMPLETE DATA FLOW │
@@ -502,31 +504,113 @@ CUDA_VISIBLE_DEVICES=0 LOGDIR=. python -u example_trainer/grpo.py \
For simple setups or debugging. Saves checkpoints and restarts vLLM to load new weights.
**IMPORTANT**: In Legacy mode, the **trainer manages its own vLLM process**. Do NOT start vLLM separately - the trainer will start, stop, and restart vLLM automatically as needed.
```bash
# Step 1: Set environment
export LOGDIR=/tmp/atropos_test
mkdir -p $LOGDIR
# Step 2: Kill any existing processes
pkill -f "vllm_api_server" || true
pkill -f "gsm8k_server" || true
sleep 2
# Step 3: Start the GSM8K environment (pointing at port 9001, where the trainer will launch vLLM)
LOGDIR=$LOGDIR python -u environments/gsm8k_server.py serve \
--env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
--env.use_wandb false \
--openai.model_name Qwen/Qwen2.5-3B-Instruct \
--openai.base_url http://localhost:9001/v1 \
--openai.server_type vllm \
> $LOGDIR/gsm8k_legacy.log 2>&1 &
sleep 5
# Step 4: Start trainer (it will launch vLLM automatically!)
CUDA_VISIBLE_DEVICES=0 python -u example_trainer/grpo.py \
--model-name Qwen/Qwen2.5-3B-Instruct \
--weight-bridge-mode none \
--vllm-port 9001 \
--training-steps 100 \
--vllm-restart-interval 10 \
--batch-size 2 \
  --lr 1e-5 \
--save-path $LOGDIR/checkpoints_legacy \
--benchmark \
2>&1 | tee $LOGDIR/trainer_legacy.log
```
**What happens:**
1. The trainer starts its own vLLM process on port 9001
2. Training proceeds, accumulating weight updates
3. Every `--vllm-restart-interval` steps, the trainer:
   - Saves a checkpoint to disk
   - Kills the current vLLM process
   - Starts a new vLLM process with the updated checkpoint
4. This cycle repeats until training completes (see the sketch below)
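
The restart cycle boils down to a save-kill-relaunch loop. A minimal sketch in Python, assuming the trainer manages the vLLM server via `subprocess` (the training-loop method names here are illustrative, not the exact code in `example_trainer/grpo.py`):

```python
import subprocess
import time

def restart_vllm(checkpoint_dir: str, port: int = 9001) -> subprocess.Popen:
    """Stop the running vLLM server and relaunch it on a new checkpoint."""
    # Kill whatever vLLM server is currently running
    subprocess.run(["pkill", "-f", "vllm_api_server"], check=False)
    time.sleep(2)
    # Relaunch vLLM pointing at the freshly saved checkpoint directory
    return subprocess.Popen([
        "python", "-u", "example_trainer/vllm_api_server.py",
        "--model", checkpoint_dir,
        "--port", str(port),
    ])

# Simplified training loop (method names are illustrative):
# for step in range(training_steps):
#     trainer.train_step(batch)
#     if step % vllm_restart_interval == 0:
#         trainer.save_checkpoint(save_path)
#         vllm_proc = restart_vllm(save_path)
```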
### LoRA Mode (Adapter Training)
Trains only adapter weights. Small checkpoints, lower memory. Requires vLLM to be started separately with `--enable-lora`.
```bash
# Step 1: Set environment
export LOGDIR=/tmp/atropos_test
mkdir -p $LOGDIR
# Step 2: Kill any existing processes
pkill -f "vllm_api_server" || true
pkill -f "gsm8k_server" || true
sleep 2
# Step 3: Start vLLM with LoRA support (use --enforce-eager to avoid Triton issues)
LOGDIR=$LOGDIR python -u example_trainer/vllm_api_server.py \
--model Qwen/Qwen2.5-3B-Instruct \
--port 9001 \
--dtype bfloat16 \
--gpu-memory-utilization 0.4 \
--enable-lora \
--max-lora-rank 32 \
--enforce-eager \
> $LOGDIR/vllm_lora.log 2>&1 &
echo "Waiting 60s for vLLM..."; sleep 60
# Verify vLLM is ready
curl -s http://localhost:9001/health && echo " vLLM is ready!"
# Step 4: Start the GSM8K environment
LOGDIR=$LOGDIR python -u environments/gsm8k_server.py serve \
--env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
--env.use_wandb false \
--openai.model_name Qwen/Qwen2.5-3B-Instruct \
--openai.base_url http://localhost:9001/v1 \
--openai.server_type vllm \
> $LOGDIR/gsm8k_lora.log 2>&1 &
sleep 10
# Step 5: Start trainer with LoRA (can use different GPU)
CUDA_VISIBLE_DEVICES=1 python -u example_trainer/grpo.py \
--model-name Qwen/Qwen2.5-3B-Instruct \
--weight-bridge-mode lora_only \
--vllm-port 9001 \
--lora-r 16 \
--lora-alpha 32 \
--training-steps 100 \
--batch-size 2 \
  --lr 1e-4 \
--save-path $LOGDIR/checkpoints_lora \
--benchmark \
2>&1 | tee $LOGDIR/trainer_lora.log
```
**What happens:**
1. vLLM runs with LoRA support enabled
2. The trainer loads the base model and creates LoRA adapters
3. After each sync interval, the trainer:
   - Saves a small LoRA adapter (~50 MB)
   - Hot-swaps the adapter into vLLM via the `/lora/load` endpoint
4. vLLM uses the new adapter for the next inference batch (see the sketch below)
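
The hot-swap itself is a single HTTP call. A minimal sketch, assuming the `/lora/load` endpoint accepts a JSON body with an adapter name and a path on disk (the payload field names are assumptions, not the server's documented schema):

```python
import requests

def hot_swap_adapter(adapter_path: str, port: int = 9001) -> None:
    """Ask the vLLM server to load a freshly saved LoRA adapter from disk."""
    resp = requests.post(
        f"http://localhost:{port}/lora/load",
        # Field names are assumptions for illustration
        json={"lora_name": "grpo_adapter", "lora_path": adapter_path},
        timeout=60,
    )
    resp.raise_for_status()

# After each sync interval (method name is illustrative):
# trainer.save_lora_adapter(adapter_dir)
# hot_swap_adapter(adapter_dir)
```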
---
## Configuration Reference
@@ -607,13 +691,13 @@ The JSON file contains everything needed to reconstruct tensor references in ano
"model": "Qwen/Qwen2.5-3B-Instruct",
"tp_degree": 1,
"dp_shard_degree": 1,
"param_names": [
"model.embed_tokens.weight",
"model.layers.0.self_attn.qkv_proj.weight",
...
],
"param_mappings": {
"model.embed_tokens.weight": {
"vllm_name": "model.embed_tokens.weight",
@@ -623,23 +707,23 @@ The JSON file contains everything needed to reconstruct tensor references in ano
},
...
},
"ipc_handles": {
"model.embed_tokens.weight": {
"device_index": 0,
"ipc_handle_b64": "AmPA0pN...",
"ipc_handle_b64": "AmPA0pN...",
"storage_size": 623902720,
"storage_offset": 0,
"ref_counter_handle_b64": "Y2JY...",
"ref_counter_offset": 0,
"event_handle_b64": "wRIs...",
"event_handle_b64": "wRIs...",
"event_sync_required": true,
"shape": [152064, 2048],
"dtype": "torch.bfloat16"
},
...
},
"shared_weights_enabled": true,
"single_copy_enabled": true,
"num_params": 255
@@ -672,15 +756,15 @@ The JSON file contains everything needed to reconstruct tensor references in ano
import base64, torch

for name, ipc_info in config["ipc_handles"].items():
    # Decode the CUDA IPC handles from base64
    ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
    ref_counter = base64.b64decode(ipc_info["ref_counter_handle_b64"])
    event = base64.b64decode(ipc_info["event_handle_b64"])
    # Reconstruct the shared storage from the IPC handle
    storage = torch.UntypedStorage._new_shared_cuda(
        ipc_info["device_index"], ipc_handle,
        ipc_info["storage_size"], ipc_info["storage_offset"],
        ref_counter, ipc_info["ref_counter_offset"],
        event, ipc_info["event_sync_required"],
    )
    # View the shared storage with the recorded shape and dtype
    dtype = getattr(torch, ipc_info["dtype"].split(".")[-1])
    tensor = torch.empty(0, dtype=dtype, device=f"cuda:{ipc_info['device_index']}")
    tensor.set_(storage, 0, ipc_info["shape"])
    # Replace the model parameter with the shared tensor
    model.get_parameter(name).data = tensor
```
@@ -816,7 +900,7 @@ pkill -9 -u $USER -f "vllm|grpo|python|run-api"
## Feature Availability Matrix
### What's Available
| Feature | Status | Notes |
|---------|--------|-------|
@@ -832,7 +916,7 @@ pkill -9 -u $USER -f "vllm|grpo|python|run-api"
| **Wandb Logging** | Working | Via `--use-wandb` flag |
| **Custom Environments** | Working | Extend `BaseEnv` class |
### What's NOT Available
| Feature | Mode | Status | Reason / Workaround |
|---------|------|--------|---------------------|
@@ -858,7 +942,7 @@ pkill -9 -u $USER -f "vllm|grpo|python|run-api"
| **LoRA** | Supported | Via vLLM | Multiple Trainers |
| **Legacy** | Supported | Via vLLM | Multiple Trainers |
> **Key Point**: The multi-GPU limitation is **ONLY for single-copy mode** due to CUDA IPC constraints.
> LoRA and Legacy modes work with standard vLLM which fully supports tensor parallelism.
#### Pipeline Parallel (PP)
@@ -956,7 +1040,7 @@ CUDA_VISIBLE_DEVICES=5 python -u example_trainer/grpo.py \
## Future Work
### High Priority
| Feature | Description |
|---------|-------------|
@@ -964,7 +1048,7 @@ CUDA_VISIBLE_DEVICES=5 python -u example_trainer/grpo.py \
| **Automatic Server Type Detection** | Auto-detect correct `server_type` for environments |
| **Checkpoint Resume** | Resume training from checkpoints seamlessly |
### Medium Priority
| Feature | Description | Difficulty |
|---------|-------------|------------|