mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00

commit a7bdc0270d ("stuff")
parent f5c847d39d

2 changed files with 295 additions and 5 deletions
@@ -371,6 +371,260 @@ pip install peft

---
## Checkpoint Locations

### Where Are Trained Models Saved?

| Mode | Location | Contents |
|------|----------|----------|
| **Legacy** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
| **Legacy** | `trained_model_checkpoints/final_model/` | Final checkpoint |
| **Shared vLLM** | `trained_model_checkpoints/step_N/` | Full model + tokenizer |
| **LoRA** | `trained_model_checkpoints/adapter_step_N/` | LoRA adapters only (~10-50MB) |
| **LoRA** | `trained_model_checkpoints/final_adapter/` | Final adapter |

### Customizing Save Path

```bash
python example_trainer/grpo.py \
    --save-path /path/to/my/checkpoints \
    ...
```

### Loading Checkpoints for Inference

```python
# Full model (Legacy/Shared modes)
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("trained_model_checkpoints/final_model")
tokenizer = AutoTokenizer.from_pretrained("trained_model_checkpoints/final_model")

# LoRA adapter
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = PeftModel.from_pretrained(base_model, "trained_model_checkpoints/final_adapter")
```

---

## vLLM Server Requirements

When using `--openai.server_type vllm` or the shared_vllm bridge, your vLLM server must expose these endpoints:

### Required Endpoints

| Endpoint | Method | Purpose | Used By |
|----------|--------|---------|---------|
| `/health` | GET | Health check | All modes |
| `/generate` | POST | Native generation with token IDs + logprobs | VLLMServer class |
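Before launching a training run it is worth probing `/health` directly. The helper below is a minimal sketch (not part of the codebase) and assumes the server is listening on `localhost:9001`:

```python
import urllib.error
import urllib.request


def is_server_healthy(base_url: str = "http://localhost:9001", timeout: float = 5.0) -> bool:
    """Return True if GET {base_url}/health answers with HTTP 200."""
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False


print(is_server_healthy())
```

A `False` result before training starts usually means the server is still loading weights or is bound to a different port.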

### Required `/generate` Request Format

The vLLM server must handle **both** prompt formats:

```json
// String prompt (simple)
{
  "prompt": "Hello, world!",
  "max_tokens": 100,
  "temperature": 1.0,
  "logprobs": 1
}

// Token ID prompt (used by atroposlib)
{
  "prompt": {"prompt_token_ids": [1, 2, 3, 4, 5]},
  "max_tokens": 100,
  "temperature": 1.0,
  "logprobs": 1
}
```
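A client can normalize to whichever shape it has on hand. The sketch below is illustrative (`build_generate_payload` is a hypothetical helper, not part of the codebase):

```python
import json


def build_generate_payload(prompt, max_tokens=100, temperature=1.0, logprobs=1):
    """Build a /generate request body from either a string or a token-ID list."""
    if isinstance(prompt, str):
        prompt_field = prompt                              # simple string prompt
    else:
        prompt_field = {"prompt_token_ids": list(prompt)}  # pre-tokenized prompt
    return {
        "prompt": prompt_field,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "logprobs": logprobs,
    }


print(json.dumps(build_generate_payload("Hello, world!")))
print(json.dumps(build_generate_payload([1, 2, 3, 4, 5])))
```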

### Required `/generate` Response Format

```json
{
  "text": ["generated text here"],
  "prompt": "original prompt",
  "finish_reasons": ["stop"],
  "logprobs": [
    [
      [{"12345": -0.5}],
      [{"67890": -1.2}]
    ]
  ],
  "prompt_token_ids": [1, 2, 3, 4, 5],
  "token_ids": [[12345, 67890, ...]]
}
```

The `logprobs` field format is `List[List[List[Dict[token_id, logprob]]]]`:

- Outer list: per completion (n samples)
- Middle list: per token in the completion
- Inner list: contains a single dict `{token_id: logprob}`
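The three nesting levels can be flattened as follows. This is a minimal sketch against the structure documented above; the sample payload is illustrative, not real server output:

```python
# Sample /generate response fragment with one completion of two tokens.
sample_response = {
    "text": ["generated text here"],
    "token_ids": [[12345, 67890]],
    "logprobs": [
        [
            [{"12345": -0.5}],
            [{"67890": -1.2}],
        ]
    ],
}


def extract_logprobs(response):
    """Flatten List[List[List[Dict[token_id, logprob]]]] into per-completion lists of floats."""
    per_completion = []
    for completion in response["logprobs"]:       # outer list: one entry per sample
        token_logprobs = []
        for token_entry in completion:            # middle list: one entry per token
            (logprob,) = token_entry[0].values()  # inner list holds one {token_id: logprob}
            token_logprobs.append(logprob)
        per_completion.append(token_logprobs)
    return per_completion


print(extract_logprobs(sample_response))  # → [[-0.5, -1.2]]
```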

### Optional Bridge Endpoints (for shared_vllm mode)

| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/bridge/info` | GET | Get bridge status |
| `/bridge/notify_update` | POST | Receive weight update notifications |
| `/bridge/state_dict_info` | GET | Get model parameter mappings |

### Optional LoRA Endpoints (for lora_only mode)

| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/lora/status` | GET | Get active LoRA adapter |
| `/lora/load` | POST | Load new LoRA adapter |
| `/lora/unload` | POST | Unload current adapter |

### Using Standard vLLM vs Custom Server

| Server | Supports `/generate` with logprobs | Supports bridge | Supports LoRA hot-swap |
|--------|-----------------------------------|-----------------|------------------------|
| `vllm serve ...` | ❌ No | ❌ No | ❌ No |
| `vllm_api_server.py` | ✅ Yes | ✅ Yes | ✅ Yes |

**Use `example_trainer/vllm_api_server.py` for full feature support.**

---

## Benchmarking Speed & Memory

### Memory Usage Comparison

```bash
# Run this during training to monitor GPU memory
watch -n 1 nvidia-smi
```

**Expected Memory Usage (Qwen2.5-3B-Instruct):**

| Mode | Trainer GPU | vLLM GPU | Total |
|------|------------|----------|-------|
| **Legacy** | ~8GB | ~8GB | ~16GB (2x model) |
| **Shared vLLM** | ~8GB (shared) | ~8GB (shared) | ~8GB (1x model) |
| **LoRA** | ~10GB (frozen base) | ~8GB | ~18GB |
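From inside the training process, the trainer-side number can be read from PyTorch's allocator instead of `nvidia-smi`. This small helper is a hypothetical sketch (not part of the trainer) that degrades gracefully when CUDA, or torch itself, is unavailable:

```python
def peak_gpu_memory_gb() -> float:
    """Peak memory allocated by torch on the current CUDA device, in GB (0.0 without CUDA)."""
    try:
        import torch
        if torch.cuda.is_available():
            return torch.cuda.max_memory_allocated() / 1e9
    except ImportError:
        pass
    return 0.0


print(f"Peak GPU memory: {peak_gpu_memory_gb():.2f} GB")
```

Note this only sees allocations made by the calling process, so the vLLM server's share must still be read from `nvidia-smi`.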

### Speed Benchmarking

Add these measurements to your training script or use the built-in wandb logging:

```python
import time

# Track step times
step_times = []
sync_times = []

for step in range(training_steps):
    # Measure training step time
    step_start = time.time()
    # ... training code ...
    step_time = time.time() - step_start
    step_times.append(step_time)

    # Measure sync time (Legacy mode only)
    if step % vllm_restart_interval == 0:
        sync_start = time.time()
        # ... checkpoint + restart vLLM ...
        sync_time = time.time() - sync_start
        sync_times.append(sync_time)

# Print summary
print(f"Avg step time: {sum(step_times)/len(step_times):.2f}s")
print(f"Avg sync time: {sum(sync_times)/len(sync_times):.2f}s" if sync_times else "No syncs")
```

### Benchmark Script

Create a benchmark comparing modes:

```bash
#!/bin/bash
# benchmark_modes.sh

MODEL="Qwen/Qwen2.5-3B-Instruct"
STEPS=50
BATCH=2
ACCUM=16

echo "=== Benchmarking Legacy Mode ==="
time python example_trainer/grpo.py \
    --model-name $MODEL \
    --weight-bridge-mode none \
    --training-steps $STEPS \
    --batch-size $BATCH \
    --gradient-accumulation-steps $ACCUM \
    --vllm-restart-interval 10 \
    2>&1 | tee benchmark_legacy.log

echo "=== Benchmarking Shared vLLM Mode ==="
export LOGDIR=/tmp/bench_shared
export NUM_INFERENCE_NODES=0
mkdir -p $LOGDIR

# Start vLLM first
python example_trainer/vllm_api_server.py \
    --model $MODEL --port 9001 --gpu-memory-utilization 0.45 &
VLLM_PID=$!
sleep 60  # Wait for vLLM to load

time python example_trainer/grpo.py \
    --model-name $MODEL \
    --weight-bridge-mode shared_vllm \
    --training-steps $STEPS \
    --batch-size $BATCH \
    --gradient-accumulation-steps $ACCUM \
    --num-inference-nodes 0 \
    2>&1 | tee benchmark_shared.log

kill $VLLM_PID

echo "=== Benchmarking LoRA Mode ==="
time python example_trainer/grpo.py \
    --model-name $MODEL \
    --weight-bridge-mode lora_only \
    --training-steps $STEPS \
    --batch-size $BATCH \
    --gradient-accumulation-steps $ACCUM \
    --lora-r 16 \
    --vllm-restart-interval 25 \
    2>&1 | tee benchmark_lora.log

echo "=== Summary ==="
echo "Check benchmark_*.log for detailed timing"
```

### Expected Benchmark Results

| Metric | Legacy | Shared vLLM | LoRA |
|--------|--------|-------------|------|
| **Step time** | ~2-5s | ~2-5s | ~1-3s |
| **Sync overhead** | ~30-60s every N steps | ~0ms | ~5-10s every N steps |
| **Total time (50 steps, sync every 10)** | ~15-20 min | ~3-5 min | ~5-8 min |
| **Peak GPU memory** | ~16GB | ~8GB | ~10GB |
| **Checkpoint size** | ~6GB | ~6GB | ~50MB |

### WandB Metrics to Watch

If using `--use-wandb`, these metrics are logged:

| Metric | Description |
|--------|-------------|
| `train/loss` | GRPO loss |
| `train/grad_norm` | Gradient norm |
| `train/pos_logp` | Log prob of positive examples |
| `train/neg_logp` | Log prob of negative examples |
| `train/step_time` | Time per training step |
| `train/sync_time` | Time for weight sync (legacy/lora) |

---

## Files in This Directory

| File | Description |
@@ -682,9 +682,37 @@ def log_metrics(
     wandb.log(log_dict, step=step)
 
 
-def finalize_training(use_wandb: bool) -> None:
-    """Clean up after training completes."""
+def finalize_training(
+    use_wandb: bool,
+    training_start_time: Optional[float] = None,
+    mode: str = "unknown",
+    total_steps: int = 0,
+) -> None:
+    """Clean up after training and log benchmark summary."""
     print("\nTraining finished.")
 
+    # Log benchmark summary
+    if training_start_time is not None:
+        total_time = time.time() - training_start_time
+        gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0
+
+        print(f"\n{'='*60}")
+        print(f"BENCHMARK SUMMARY ({mode})")
+        print(f"{'='*60}")
+        print(f"  Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
+        print(f"  Total steps: {total_steps}")
+        print(f"  Avg time per step: {total_time/max(total_steps,1):.2f}s")
+        print(f"  Peak GPU memory: {gpu_mem_gb:.2f} GB")
+        print(f"{'='*60}\n")
+
+        if use_wandb:
+            wandb.summary["benchmark/total_time_seconds"] = total_time
+            wandb.summary["benchmark/total_time_minutes"] = total_time / 60
+            wandb.summary["benchmark/avg_step_time_seconds"] = total_time / max(total_steps, 1)
+            wandb.summary["benchmark/peak_gpu_memory_gb"] = gpu_mem_gb
+            wandb.summary["benchmark/mode"] = mode
+            wandb.summary["benchmark/total_steps"] = total_steps
+
     if use_wandb:
         wandb.finish()
@@ -697,6 +725,7 @@ def train(config: TrainingConfig):
     Use weight_bridge_mode='shared_vllm' for in-place weight updates without restarts.
     """
     global vllm_process
+    training_start_time = time.time()
 
     # === Setup ===
     use_wandb = setup_wandb(config)
@@ -708,6 +737,7 @@ def train(config: TrainingConfig):
     print(f"{'='*60}")
     print(f"Training for {config.training_steps} steps on {config.device}")
     print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
+    print(f"Save path: {config.save_path}")
     print(f"{'='*60}\n")
 
     os.makedirs(config.save_path, exist_ok=True)
@@ -753,8 +783,8 @@ def train(config: TrainingConfig):
     _check_vllm_health()
 
     # === Cleanup ===
-    finalize_training(use_wandb)
     save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
+    finalize_training(use_wandb, training_start_time, "legacy", config.training_steps)
 
 
 # =============================================================================
@@ -846,6 +876,8 @@ def train_shared_vllm(config: TrainingConfig):
             "Ensure vllm_weight_bridge.py is in the same directory."
         )
 
+    training_start_time = time.time()
+
     # === Setup ===
     use_wandb = setup_wandb(config)
 
@@ -856,6 +888,7 @@ def train_shared_vllm(config: TrainingConfig):
     print(f"Distributed: rank={config.trainer_rank}/{config.world_size}")
     print(f"Init method: {config.init_method}")
     print(f"Inference nodes: {config.num_inference_nodes}")
+    print(f"Save path: {config.save_path}")
     print(f"{'='*60}\n")
 
     # Initialize weight bridge
@@ -907,8 +940,8 @@ def train_shared_vllm(config: TrainingConfig):
 
     # === Cleanup ===
     bridge.cleanup()
-    finalize_training(use_wandb)
     save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
+    finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps)
 
 
 def train_lora(config: TrainingConfig):
@@ -926,6 +959,8 @@ def train_lora(config: TrainingConfig):
             "PEFT library required for LoRA mode. Install with: pip install peft"
         )
 
+    training_start_time = time.time()
+
     # === Setup ===
     use_wandb = setup_wandb(config)
 
@@ -934,6 +969,7 @@ def train_lora(config: TrainingConfig):
     print(f"{'='*60}")
     print(f"Base model: {config.model_name}")
     print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
+    print(f"Save path: {config.save_path}")
     print(f"{'='*60}\n")
 
     # Load model with LoRA adapters
@@ -984,10 +1020,10 @@ def train_lora(config: TrainingConfig):
 
     # === Cleanup ===
     _terminate_vllm_process()
-    finalize_training(use_wandb)
 
     # Save final adapter
     save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True)
+    finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps)
 
     # Also save tokenizer for convenience
     tokenizer_path = os.path.join(config.save_path, "tokenizer")