mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
README updates
This commit is contained in:
parent
5c0ac94095
commit
ee23761709
1 changed files with 104 additions and 20 deletions
|
|
@ -188,7 +188,7 @@ You should see:
|
|||
[vLLM Patch] ✓ Exported 339 params to vllm_bridge_config.json
|
||||
```
|
||||
|
||||
#### Step 7: Start an Environment (THE EXAMPLE HERE IS GSM8K in this case)
|
||||
#### Step 7: Start an Environment (GSM8K in this case)
|
||||
|
||||
```bash
|
||||
python environments/gsm8k_server.py serve \
|
||||
|
|
@ -404,6 +404,8 @@ CUDA_VISIBLE_DEVICES=0 LOGDIR=. python -u example_trainer/grpo.py \
|
|||
|
||||
**The Simple Approach**: Save full checkpoints, restart vLLM to load new weights.
|
||||
|
||||
> **Note**: In Legacy mode, the **trainer manages its own vLLM process**. Do NOT start vLLM separately - the trainer will automatically start, stop, and restart vLLM with updated checkpoints.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────────────────┐
|
||||
│ LEGACY MODE - COMPLETE DATA FLOW │
|
||||
|
|
@ -502,31 +504,113 @@ CUDA_VISIBLE_DEVICES=0 LOGDIR=. python -u example_trainer/grpo.py \
|
|||
|
||||
For simple setups or debugging. Saves checkpoints and restarts vLLM to load new weights.
|
||||
|
||||
**IMPORTANT**: In Legacy mode, the **trainer manages its own vLLM process**. Do NOT start vLLM separately - the trainer will start, stop, and restart vLLM automatically as needed.
|
||||
|
||||
```bash
|
||||
python example_trainer/grpo.py \
|
||||
# Step 1: Set environment
|
||||
export LOGDIR=/tmp/atropos_test
|
||||
mkdir -p $LOGDIR
|
||||
|
||||
# Step 2: Kill any existing processes
|
||||
pkill -f "vllm_api_server" || true
|
||||
pkill -f "gsm8k_server" || true
|
||||
sleep 2
|
||||
|
||||
# Step 3: Start GSM8k environment (pointing to port 9001 where trainer will launch vLLM)
|
||||
LOGDIR=$LOGDIR python -u environments/gsm8k_server.py serve \
|
||||
--env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
|
||||
--env.use_wandb false \
|
||||
--openai.model_name Qwen/Qwen2.5-3B-Instruct \
|
||||
--openai.base_url http://localhost:9001/v1 \
|
||||
--openai.server_type vllm \
|
||||
> $LOGDIR/gsm8k_legacy.log 2>&1 &
|
||||
sleep 5
|
||||
|
||||
# Step 4: Start trainer (it will launch vLLM automatically!)
|
||||
CUDA_VISIBLE_DEVICES=0 python -u example_trainer/grpo.py \
|
||||
--model-name Qwen/Qwen2.5-3B-Instruct \
|
||||
--weight-bridge-mode none \
|
||||
--vllm-port 9001 \
|
||||
--training-steps 100 \
|
||||
--vllm-restart-interval 10 \
|
||||
--batch-size 2 \
|
||||
--lr 1e-5
|
||||
--lr 1e-5 \
|
||||
--save-path $LOGDIR/checkpoints_legacy \
|
||||
--benchmark \
|
||||
2>&1 | tee $LOGDIR/trainer_legacy.log
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. Trainer starts its own vLLM process on port 9001
|
||||
2. Training proceeds, accumulating weight updates
|
||||
3. Every `--vllm-restart-interval` steps, trainer:
|
||||
- Saves a checkpoint to disk
|
||||
- Kills the current vLLM process
|
||||
- Starts a new vLLM process with the updated checkpoint
|
||||
4. This continues until training completes
|
||||
|
||||
### LoRA Mode (Adapter Training)
|
||||
|
||||
Trains only adapter weights. Small checkpoints, lower memory.
|
||||
Trains only adapter weights. Small checkpoints, lower memory. Requires vLLM to be started separately with `--enable-lora`.
|
||||
|
||||
```bash
|
||||
python example_trainer/grpo.py \
|
||||
# Step 1: Set environment
|
||||
export LOGDIR=/tmp/atropos_test
|
||||
mkdir -p $LOGDIR
|
||||
|
||||
# Step 2: Kill any existing processes
|
||||
pkill -f "vllm_api_server" || true
|
||||
pkill -f "gsm8k_server" || true
|
||||
sleep 2
|
||||
|
||||
# Step 3: Start vLLM with LoRA support (use --enforce-eager to avoid Triton issues)
|
||||
LOGDIR=$LOGDIR python -u example_trainer/vllm_api_server.py \
|
||||
--model Qwen/Qwen2.5-3B-Instruct \
|
||||
--port 9001 \
|
||||
--dtype bfloat16 \
|
||||
--gpu-memory-utilization 0.4 \
|
||||
--enable-lora \
|
||||
--max-lora-rank 32 \
|
||||
--enforce-eager \
|
||||
> $LOGDIR/vllm_lora.log 2>&1 &
|
||||
echo "Waiting 60s for vLLM..."; sleep 60
|
||||
|
||||
# Verify vLLM is ready
|
||||
curl -s http://localhost:9001/health && echo " vLLM is ready!"
|
||||
|
||||
# Step 4: Start GSM8k environment
|
||||
LOGDIR=$LOGDIR python -u environments/gsm8k_server.py serve \
|
||||
--env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \
|
||||
--env.use_wandb false \
|
||||
--openai.model_name Qwen/Qwen2.5-3B-Instruct \
|
||||
--openai.base_url http://localhost:9001/v1 \
|
||||
--openai.server_type vllm \
|
||||
> $LOGDIR/gsm8k_lora.log 2>&1 &
|
||||
sleep 10
|
||||
|
||||
# Step 5: Start trainer with LoRA (can use different GPU)
|
||||
CUDA_VISIBLE_DEVICES=1 python -u example_trainer/grpo.py \
|
||||
--model-name Qwen/Qwen2.5-3B-Instruct \
|
||||
--weight-bridge-mode lora_only \
|
||||
--vllm-port 9001 \
|
||||
--lora-r 16 \
|
||||
--lora-alpha 32 \
|
||||
--training-steps 100 \
|
||||
--batch-size 2 \
|
||||
--lr 1e-4
|
||||
--lr 1e-4 \
|
||||
--save-path $LOGDIR/checkpoints_lora \
|
||||
--benchmark \
|
||||
2>&1 | tee $LOGDIR/trainer_lora.log
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. vLLM runs with LoRA support enabled
|
||||
2. Trainer loads base model + creates LoRA adapters
|
||||
3. After each sync interval, trainer:
|
||||
- Saves small LoRA adapter (~50MB)
|
||||
- Hot-swaps adapter to vLLM via `/lora/load` endpoint
|
||||
4. vLLM uses new adapter for next inference batch
|
||||
|
||||
---
|
||||
|
||||
## Configuration Reference
|
||||
|
|
@ -607,13 +691,13 @@ The JSON file contains everything needed to reconstruct tensor references in ano
|
|||
"model": "Qwen/Qwen2.5-3B-Instruct",
|
||||
"tp_degree": 1,
|
||||
"dp_shard_degree": 1,
|
||||
|
||||
|
||||
"param_names": [
|
||||
"model.embed_tokens.weight",
|
||||
"model.layers.0.self_attn.qkv_proj.weight",
|
||||
...
|
||||
],
|
||||
|
||||
|
||||
"param_mappings": {
|
||||
"model.embed_tokens.weight": {
|
||||
"vllm_name": "model.embed_tokens.weight",
|
||||
|
|
@ -623,23 +707,23 @@ The JSON file contains everything needed to reconstruct tensor references in ano
|
|||
},
|
||||
...
|
||||
},
|
||||
|
||||
|
||||
"ipc_handles": {
|
||||
"model.embed_tokens.weight": {
|
||||
"device_index": 0,
|
||||
"ipc_handle_b64": "AmPA0pN...",
|
||||
"ipc_handle_b64": "AmPA0pN...",
|
||||
"storage_size": 623902720,
|
||||
"storage_offset": 0,
|
||||
"ref_counter_handle_b64": "Y2JY...",
|
||||
"ref_counter_offset": 0,
|
||||
"event_handle_b64": "wRIs...",
|
||||
"event_handle_b64": "wRIs...",
|
||||
"event_sync_required": true,
|
||||
"shape": [152064, 2048],
|
||||
"dtype": "torch.bfloat16"
|
||||
},
|
||||
...
|
||||
},
|
||||
|
||||
|
||||
"shared_weights_enabled": true,
|
||||
"single_copy_enabled": true,
|
||||
"num_params": 255
|
||||
|
|
@ -672,15 +756,15 @@ The JSON file contains everything needed to reconstruct tensor references in ano
|
|||
for name, ipc_info in config["ipc_handles"].items():
|
||||
# Decode IPC handle from base64
|
||||
ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
|
||||
|
||||
|
||||
# Reconstruct storage from IPC handle
|
||||
storage = torch.UntypedStorage._new_shared_cuda(
|
||||
device_index, ipc_handle, storage_size, ...
|
||||
)
|
||||
|
||||
|
||||
# Create tensor from shared storage
|
||||
tensor = torch.tensor(storage).view(shape).to(dtype)
|
||||
|
||||
|
||||
# Replace model parameter with shared tensor
|
||||
model.get_parameter(name).data = tensor
|
||||
```
|
||||
|
|
@ -816,7 +900,7 @@ pkill -9 -u $USER -f "vllm|grpo|python|run-api"
|
|||
|
||||
## Feature Availability Matrix
|
||||
|
||||
### What's Available
|
||||
### What's Available
|
||||
|
||||
| Feature | Status | Notes |
|
||||
|---------|--------|-------|
|
||||
|
|
@ -832,7 +916,7 @@ pkill -9 -u $USER -f "vllm|grpo|python|run-api"
|
|||
| **Wandb Logging** | Working | Via `--use-wandb` flag |
|
||||
| **Custom Environments** | Working | Extend `BaseEnv` class |
|
||||
|
||||
### What's NOT Available
|
||||
### What's NOT Available
|
||||
|
||||
| Feature | Mode | Status | Reason / Workaround |
|
||||
|---------|------|--------|---------------------|
|
||||
|
|
@ -858,7 +942,7 @@ pkill -9 -u $USER -f "vllm|grpo|python|run-api"
|
|||
| **LoRA** | Supported | Via vLLM | Multiple Trainers |
|
||||
| **Legacy** | Supported | Via vLLM | Multiple Trainers |
|
||||
|
||||
> **Key Point**: The multi-GPU limitation is **ONLY for single-copy mode** due to CUDA IPC constraints.
|
||||
> **Key Point**: The multi-GPU limitation is **ONLY for single-copy mode** due to CUDA IPC constraints.
|
||||
> LoRA and Legacy modes work with standard vLLM which fully supports tensor parallelism.
|
||||
|
||||
#### Pipeline Parallel (PP)
|
||||
|
|
@ -956,7 +1040,7 @@ CUDA_VISIBLE_DEVICES=5 python -u example_trainer/grpo.py \
|
|||
|
||||
## Future Work
|
||||
|
||||
### High Priority
|
||||
### High Priority
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
|
|
@ -964,7 +1048,7 @@ CUDA_VISIBLE_DEVICES=5 python -u example_trainer/grpo.py \
|
|||
| **Automatic Server Type Detection** | Auto-detect correct `server_type` for environments |
|
||||
| **Checkpoint Resume** | Resume training from checkpoints seamlessly |
|
||||
|
||||
### Medium Priority
|
||||
### Medium Priority
|
||||
|
||||
| Feature | Description | Difficulty |
|
||||
|---------|-------------|------------|
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue