feedback fixes: shared layers + hard coded values + warmup steps

This commit is contained in:
Jai Suphavadeeprasit 2026-02-24 10:28:44 -05:00
parent e1f9b926bb
commit 624b3cdabe
9 changed files with 247 additions and 58 deletions

View file

@ -151,9 +151,10 @@ python -m example_trainer.grpo \
--atropos-url "http://localhost:8002" \
--batch-size 4 \
--gradient-accumulation-steps 4 \
--warmup-steps 20 \
--lr 1e-5 \
--training-steps 30 \
--kl-coef 0.1 \
--kl-coef 0.0 \
--clip-eps 0.2 \
--vllm-restart-interval 5 \
--save-path ./lora_checkpoints \
@ -258,7 +259,8 @@ python -m example_trainer.grpo \
--vllm-port 9001 \
--vllm-config-path /tmp/grpo_training/vllm_bridge_config.json \
--atropos-url "http://localhost:8002" \
--kl-coef 0.1 \
--warmup-steps 20 \
--kl-coef 0.0 \
--clip-eps 0.2
```
@ -307,7 +309,7 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
**CRITICAL:** Without these hyperparameters, training WILL collapse (reward hacking):
```bash
--kl-coef 0.1 # Prevents policy from drifting too far from reference
--kl-coef 0.0 # Default (disable KL penalty)
--clip-eps 0.2 # Limits importance sampling ratio to [0.8, 1.2]
```
@ -328,6 +330,16 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
- `mean_ratio` diverges far from 1.0
- `mean_kl` explodes (> 1.0)
### 3. Use LR Warmup for Stability
Use a short linear warmup when training from fresh runs or small batch settings:
```bash
--warmup-steps 20
```
This linearly ramps learning rate from 0 to `--lr` over the first N optimizer steps.
**Healthy training metrics:**
- `mean_ratio`: 0.8 - 1.2 (close to 1.0)
- `mean_kl`: 0.01 - 0.1
@ -355,9 +367,9 @@ The trainer supports multiple optimizer options to trade off between speed, memo
| Optimizer | GPU Memory for States | Speed | Precision | Dependencies |
|-----------|----------------------|-------|-----------|--------------|
| `adamw` | ~32GB (for 8B model) | Fastest | Full FP32 | None |
| `adamw_8bit` (default) | ~8GB | Fast | 8-bit quantized | `bitsandbytes` |
| `adafactor` | ~8GB | Fast | Full (no momentum) | `transformers` |
| `adamw` | Highest | Fastest | Full FP32 | None |
| `adamw_8bit` (default) | Lower | Fast | 8-bit quantized | `bitsandbytes` |
| `adafactor` | Lower | Fast | Full (no momentum) | `transformers` |
**Usage:**
```bash
@ -571,13 +583,15 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) |
| `--batch-size` | 2 | Micro-batch size |
| `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum |
| `--warmup-steps` | 0 | Linear LR warmup steps (0 disables warmup) |
| `--seq-len` | 2048 | Maximum sequence length |
| `--train-layer-indices` | None | Optional full-model layer filter for shared/legacy modes (examples: `20-31`, `0-3,28-31`) |
### GRPO Hyperparameters
| Argument | Default | Description |
|----------|---------|-------------|
| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) |
| `--kl-coef` | 0.0 | KL penalty strength (higher = more conservative) |
| `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] |
| `--lr` | 1e-5 | Learning rate (NOT --learning-rate) |
| `--no-reference-logprobs` | False | Disable GRPO reference logprobs (falls back to REINFORCE-style updates) |
@ -592,9 +606,9 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `--lora-target-modules` | None | Module names to apply LoRA (`None` falls back to `q_proj v_proj`) |
| `--lora-layer-indices` | None | Optional layer filter (examples: `20-31`, `0-3,28-31`) |
### LoRA Layer Index Guide (by Architecture)
### Layer Index Guide (by Architecture)
`--lora-layer-indices` is model-dependent. Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
Layer-index arguments are model-dependent (`--train-layer-indices` for full/shared modes, `--lora-layer-indices` for LoRA modes). Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
| Architecture family | Common config fields | Typical layer list path | Notes |
|---------------------|----------------------|-------------------------|-------|
@ -628,10 +642,10 @@ PY
If your model has `N` layers:
- Full layers: omit `--lora-layer-indices`
- Top 25%: `--lora-layer-indices {int(0.75*N)}-{N-1}`
- Top 50%: `--lora-layer-indices {int(0.5*N)}-{N-1}`
- Last 12 layers: `--lora-layer-indices {N-12}-{N-1}` (if `N >= 12`)
- Full layers: omit `--train-layer-indices`
- Top 25%: `--train-layer-indices {int(0.75*N)}-{N-1}`
- Top 50%: `--train-layer-indices {int(0.5*N)}-{N-1}`
- Last 12 layers: `--train-layer-indices {N-12}-{N-1}` (if `N >= 12`)
### vLLM Arguments