Mirror of https://github.com/NousResearch/atropos.git
feedback fixes: shared layers + hard coded values + warmup steps
Parent: e1f9b926bb
Commit: 624b3cdabe
9 changed files with 247 additions and 58 deletions
````diff
@@ -151,9 +151,10 @@ python -m example_trainer.grpo \
     --atropos-url "http://localhost:8002" \
     --batch-size 4 \
     --gradient-accumulation-steps 4 \
+    --warmup-steps 20 \
     --lr 1e-5 \
     --training-steps 30 \
-    --kl-coef 0.1 \
+    --kl-coef 0.0 \
     --clip-eps 0.2 \
     --vllm-restart-interval 5 \
     --save-path ./lora_checkpoints \
````
````diff
@@ -258,7 +259,8 @@ python -m example_trainer.grpo \
     --vllm-port 9001 \
     --vllm-config-path /tmp/grpo_training/vllm_bridge_config.json \
     --atropos-url "http://localhost:8002" \
-    --kl-coef 0.1 \
+    --warmup-steps 20 \
+    --kl-coef 0.0 \
     --clip-eps 0.2
 ```
 
````
````diff
@@ -307,7 +309,7 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
 **CRITICAL:** Without these hyperparameters, training WILL collapse (reward hacking):
 
 ```bash
---kl-coef 0.1    # Prevents policy from drifting too far from reference
+--kl-coef 0.0    # Default (disable KL penalty)
 --clip-eps 0.2   # Limits importance sampling ratio to [0.8, 1.2]
 ```
 
````
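For intuition on what these two flags control, here is a minimal sketch (not the trainer's actual code; all names and the KL estimator are illustrative) of how `--kl-coef` and `--clip-eps` typically enter a GRPO/PPO-style token loss:

```python
import torch

def grpo_token_loss(logp_new, logp_old, logp_ref, advantages,
                    clip_eps=0.2, kl_coef=0.0):
    """Clipped surrogate loss with an optional KL penalty (illustrative)."""
    ratio = torch.exp(logp_new - logp_old)                        # importance ratio
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)  # [0.8, 1.2] at eps=0.2
    surrogate = -torch.min(ratio * advantages, clipped * advantages)
    # k3-style estimate of KL against the frozen reference policy;
    # kl_coef=0.0 (the new default) switches the penalty off entirely
    kl = torch.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1.0
    return (surrogate + kl_coef * kl).mean()
```

With `kl_coef=0.0` the penalty term vanishes, matching the new default, while `--clip-eps 0.2` still bounds each token's contribution through the clamped ratio.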
````diff
@@ -328,6 +330,16 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
 - `mean_ratio` diverges far from 1.0
 - `mean_kl` explodes (> 1.0)
 
+### 3. Use LR Warmup for Stability
+
+Use a short linear warmup when training from fresh runs or small batch settings:
+
+```bash
+--warmup-steps 20
+```
+
+This linearly ramps the learning rate from 0 to `--lr` over the first N optimizer steps.
+
 **Healthy training metrics:**
 - `mean_ratio`: 0.8 - 1.2 (close to 1.0)
 - `mean_kl`: 0.01 - 0.1
````
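The warmup behaviour described in the added section can be reproduced with a standard PyTorch scheduler. A minimal sketch, assuming the trainer's warmup is the usual linear ramp (the model and loop here are stand-ins, not the trainer's code):

```python
import torch

model = torch.nn.Linear(8, 8)                          # stand-in model
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)   # base LR = --lr
warmup_steps = 20                                      # = --warmup-steps

# Linear ramp: LR scales by min(1, (step+1)/warmup_steps), i.e. 0 -> --lr
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)

for step in range(30):                                 # = --training-steps
    opt.step()                                         # loss/backward omitted
    sched.step()
```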
````diff
@@ -355,9 +367,9 @@ The trainer supports multiple optimizer options to trade off between speed, memo
 
 | Optimizer | GPU Memory for States | Speed | Precision | Dependencies |
 |-----------|----------------------|-------|-----------|--------------|
-| `adamw` | ~32GB (for 8B model) | Fastest | Full FP32 | None |
-| `adamw_8bit` (default) | ~8GB | Fast | 8-bit quantized | `bitsandbytes` |
-| `adafactor` | ~8GB | Fast | Full (no momentum) | `transformers` |
+| `adamw` | Highest | Fastest | Full FP32 | None |
+| `adamw_8bit` (default) | Lower | Fast | 8-bit quantized | `bitsandbytes` |
+| `adafactor` | Lower | Fast | Full (no momentum) | `transformers` |
 
 **Usage:**
 ```bash
````
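The switch from fixed "~GB" figures to qualitative labels makes sense: optimizer-state memory depends on how many parameters are actually trainable. A back-of-envelope estimate using rule-of-thumb byte counts (an assumption, not measured values from this trainer):

```python
def optim_state_gb(n_trainable: float, bytes_per_param: float) -> float:
    """Rough optimizer-state footprint in GB (assumes states dominate)."""
    return n_trainable * bytes_per_param / 1e9

n = 8e9  # e.g. all parameters of an 8B model trainable
print(optim_state_gb(n, 8))  # adamw: two FP32 states (m, v) ~ 8 B/param -> ~64 GB
print(optim_state_gb(n, 2))  # adamw_8bit: two 8-bit states ~ 2 B/param -> ~16 GB
```

Under LoRA, only adapter parameters carry optimizer state, so the same run can need orders of magnitude less memory; that is why fixed per-model numbers were misleading.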
````diff
@@ -571,13 +583,15 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
 | `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) |
 | `--batch-size` | 2 | Micro-batch size |
 | `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum |
+| `--warmup-steps` | 0 | Linear LR warmup steps (0 disables warmup) |
 | `--seq-len` | 2048 | Maximum sequence length |
+| `--train-layer-indices` | None | Optional full-model layer filter for shared/legacy modes (examples: `20-31`, `0-3,28-31`) |
 
 ### GRPO Hyperparameters
 
 | Argument | Default | Description |
 |----------|---------|-------------|
-| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) |
+| `--kl-coef` | 0.0 | KL penalty strength (higher = more conservative) |
 | `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] |
 | `--lr` | 1e-5 | Learning rate (NOT --learning-rate) |
 | `--no-reference-logprobs` | False | Disable GRPO reference logprobs (falls back to REINFORCE-style updates) |
````
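Both `--train-layer-indices` and `--lora-layer-indices` accept specs like `20-31` or `0-3,28-31`. A hypothetical expansion of that syntax, for illustration only (the trainer's actual parser may differ):

```python
def parse_layer_indices(spec: str) -> list[int]:
    """Expand '0-3,28-31'-style specs into sorted layer indices (hypothetical)."""
    out: set[int] = set()
    for part in spec.split(","):
        if "-" in part:
            lo, hi = part.split("-")
            out.update(range(int(lo), int(hi) + 1))  # ranges read as inclusive
        else:
            out.add(int(part))
    return sorted(out)

assert parse_layer_indices("0-3,28-31") == [0, 1, 2, 3, 28, 29, 30, 31]
```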
````diff
@@ -592,9 +606,9 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
 | `--lora-target-modules` | None | Module names to apply LoRA (`None` falls back to `q_proj v_proj`) |
 | `--lora-layer-indices` | None | Optional layer filter (examples: `20-31`, `0-3,28-31`) |
 
-### LoRA Layer Index Guide (by Architecture)
+### Layer Index Guide (by Architecture)
 
-`--lora-layer-indices` is model-dependent. Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
+Layer-index arguments are model-dependent (`--train-layer-indices` for full/shared modes, `--lora-layer-indices` for LoRA modes). Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
 
 | Architecture family | Common config fields | Typical layer list path | Notes |
 |---------------------|----------------------|-------------------------|-------|
````
````diff
@@ -628,10 +642,10 @@ PY
 
 If your model has `N` layers:
 
-- Full layers: omit `--lora-layer-indices`
-- Top 25%: `--lora-layer-indices {int(0.75*N)}-{N-1}`
-- Top 50%: `--lora-layer-indices {int(0.5*N)}-{N-1}`
-- Last 12 layers: `--lora-layer-indices {N-12}-{N-1}` (if `N >= 12`)
+- Full layers: omit `--train-layer-indices`
+- Top 25%: `--train-layer-indices {int(0.75*N)}-{N-1}`
+- Top 50%: `--train-layer-indices {int(0.5*N)}-{N-1}`
+- Last 12 layers: `--train-layer-indices {N-12}-{N-1}` (if `N >= 12`)
 
 ### vLLM Arguments
 
````
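To make those placeholder expressions concrete, a small sketch that prints the resulting flags for a hypothetical 32-layer model (read `N` from your model's config, e.g. `num_hidden_layers` for Llama-style architectures):

```python
N = 32  # e.g. config.num_hidden_layers for a 32-block model

print(f"top 25%: --train-layer-indices {int(0.75 * N)}-{N - 1}")   # 24-31
print(f"top 50%: --train-layer-indices {int(0.5 * N)}-{N - 1}")    # 16-31
if N >= 12:
    print(f"last 12: --train-layer-indices {N - 12}-{N - 1}")      # 20-31
```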