mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
feedback fixes: shared layers + hard coded values + warmup steps
This commit is contained in:
parent
e1f9b926bb
commit
624b3cdabe
9 changed files with 247 additions and 58 deletions
|
|
@ -151,9 +151,10 @@ python -m example_trainer.grpo \
|
|||
--atropos-url "http://localhost:8002" \
|
||||
--batch-size 4 \
|
||||
--gradient-accumulation-steps 4 \
|
||||
--warmup-steps 20 \
|
||||
--lr 1e-5 \
|
||||
--training-steps 30 \
|
||||
--kl-coef 0.1 \
|
||||
--kl-coef 0.0 \
|
||||
--clip-eps 0.2 \
|
||||
--vllm-restart-interval 5 \
|
||||
--save-path ./lora_checkpoints \
|
||||
|
|
@ -258,7 +259,8 @@ python -m example_trainer.grpo \
|
|||
--vllm-port 9001 \
|
||||
--vllm-config-path /tmp/grpo_training/vllm_bridge_config.json \
|
||||
--atropos-url "http://localhost:8002" \
|
||||
--kl-coef 0.1 \
|
||||
--warmup-steps 20 \
|
||||
--kl-coef 0.0 \
|
||||
--clip-eps 0.2
|
||||
```
|
||||
|
||||
|
|
@ -307,7 +309,7 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
|
|||
**CRITICAL:** Without these hyperparameters, training WILL collapse (reward hacking):
|
||||
|
||||
```bash
|
||||
--kl-coef 0.1 # Prevents policy from drifting too far from reference
|
||||
--kl-coef 0.0 # Default (disable KL penalty)
|
||||
--clip-eps 0.2 # Limits importance sampling ratio to [0.8, 1.2]
|
||||
```
|
||||
|
||||
|
|
@ -328,6 +330,16 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
|
|||
- `mean_ratio` diverges far from 1.0
|
||||
- `mean_kl` explodes (> 1.0)
|
||||
|
||||
### 3. Use LR Warmup for Stability
|
||||
|
||||
Use a short linear warmup when training from fresh runs or small batch settings:
|
||||
|
||||
```bash
|
||||
--warmup-steps 20
|
||||
```
|
||||
|
||||
This linearly ramps learning rate from 0 to `--lr` over the first N optimizer steps.
|
||||
|
||||
**Healthy training metrics:**
|
||||
- `mean_ratio`: 0.8 - 1.2 (close to 1.0)
|
||||
- `mean_kl`: 0.01 - 0.1
|
||||
|
|
@ -355,9 +367,9 @@ The trainer supports multiple optimizer options to trade off between speed, memo
|
|||
|
||||
| Optimizer | GPU Memory for States | Speed | Precision | Dependencies |
|
||||
|-----------|----------------------|-------|-----------|--------------|
|
||||
| `adamw` | ~32GB (for 8B model) | Fastest | Full FP32 | None |
|
||||
| `adamw_8bit` (default) | ~8GB | Fast | 8-bit quantized | `bitsandbytes` |
|
||||
| `adafactor` | ~8GB | Fast | Full (no momentum) | `transformers` |
|
||||
| `adamw` | Highest | Fastest | Full FP32 | None |
|
||||
| `adamw_8bit` (default) | Lower | Fast | 8-bit quantized | `bitsandbytes` |
|
||||
| `adafactor` | Lower | Fast | Full (no momentum) | `transformers` |
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
|
|
@ -571,13 +583,15 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
|
|||
| `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) |
|
||||
| `--batch-size` | 2 | Micro-batch size |
|
||||
| `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum |
|
||||
| `--warmup-steps` | 0 | Linear LR warmup steps (0 disables warmup) |
|
||||
| `--seq-len` | 2048 | Maximum sequence length |
|
||||
| `--train-layer-indices` | None | Optional full-model layer filter for shared/legacy modes (examples: `20-31`, `0-3,28-31`) |
|
||||
|
||||
### GRPO Hyperparameters
|
||||
|
||||
| Argument | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) |
|
||||
| `--kl-coef` | 0.0 | KL penalty strength (higher = more conservative) |
|
||||
| `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] |
|
||||
| `--lr` | 1e-5 | Learning rate (NOT --learning-rate) |
|
||||
| `--no-reference-logprobs` | False | Disable GRPO reference logprobs (falls back to REINFORCE-style updates) |
|
||||
|
|
@ -592,9 +606,9 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
|
|||
| `--lora-target-modules` | None | Module names to apply LoRA (`None` falls back to `q_proj v_proj`) |
|
||||
| `--lora-layer-indices` | None | Optional layer filter (examples: `20-31`, `0-3,28-31`) |
|
||||
|
||||
### LoRA Layer Index Guide (by Architecture)
|
||||
### Layer Index Guide (by Architecture)
|
||||
|
||||
`--lora-layer-indices` is model-dependent. Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
|
||||
Layer-index arguments are model-dependent (`--train-layer-indices` for full/shared modes, `--lora-layer-indices` for LoRA modes). Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
|
||||
|
||||
| Architecture family | Common config fields | Typical layer list path | Notes |
|
||||
|---------------------|----------------------|-------------------------|-------|
|
||||
|
|
@ -628,10 +642,10 @@ PY
|
|||
|
||||
If your model has `N` layers:
|
||||
|
||||
- Full layers: omit `--lora-layer-indices`
|
||||
- Top 25%: `--lora-layer-indices {int(0.75*N)}-{N-1}`
|
||||
- Top 50%: `--lora-layer-indices {int(0.5*N)}-{N-1}`
|
||||
- Last 12 layers: `--lora-layer-indices {N-12}-{N-1}` (if `N >= 12`)
|
||||
- Full layers: omit `--train-layer-indices`
|
||||
- Top 25%: `--train-layer-indices {int(0.75*N)}-{N-1}`
|
||||
- Top 50%: `--train-layer-indices {int(0.5*N)}-{N-1}`
|
||||
- Last 12 layers: `--train-layer-indices {N-12}-{N-1}` (if `N >= 12`)
|
||||
|
||||
### vLLM Arguments
|
||||
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ def save_checkpoint(
|
|||
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
|
||||
model.config.save_pretrained(checkpoint_path)
|
||||
|
||||
# CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
|
||||
# CRITICAL: Clean up the copied state_dict to free significant GPU memory.
|
||||
del state_dict
|
||||
import gc
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from .config import TrainingConfig
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def _parse_lora_layer_indices(value: str) -> Optional[List[int]]:
|
||||
def _parse_layer_indices(value: str) -> Optional[List[int]]:
|
||||
"""
|
||||
Parse LoRA layer indices from comma/range syntax.
|
||||
|
||||
|
|
@ -110,6 +110,12 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=32,
|
||||
help="Number of gradient accumulation steps",
|
||||
)
|
||||
group.add_argument(
|
||||
"--warmup-steps",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Linear LR warmup steps (0 disables warmup).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--optimizer",
|
||||
type=str,
|
||||
|
|
@ -118,6 +124,16 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
|
|||
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
|
||||
"'adafactor' (no momentum)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--adafactor-scale-parameter",
|
||||
action="store_true",
|
||||
help="Enable Adafactor scale_parameter behavior (only used when --optimizer adafactor).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--adafactor-relative-step",
|
||||
action="store_true",
|
||||
help="Enable Adafactor relative_step behavior (only used when --optimizer adafactor).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
|
|
@ -144,8 +160,8 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
|
|||
group.add_argument(
|
||||
"--kl-coef",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="KL divergence penalty coefficient (beta). Higher = more conservative.",
|
||||
default=0.0,
|
||||
help="Sampled-token KL-like regularization coefficient. Higher = more conservative.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--clip-eps",
|
||||
|
|
@ -256,6 +272,15 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=None,
|
||||
help="Explicit path to vllm_bridge_config.json (auto-detected if not provided)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--train-layer-indices",
|
||||
type=_parse_layer_indices,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional transformer layer indices to train in full/shared modes, e.g. "
|
||||
"'20-31' or '0-3,28-31'. If omitted, all layers are trainable."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
||||
|
|
@ -275,7 +300,7 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
|||
)
|
||||
group.add_argument(
|
||||
"--lora-layer-indices",
|
||||
type=_parse_lora_layer_indices,
|
||||
type=_parse_layer_indices,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional layer indices to apply LoRA to, e.g. '20-31' or "
|
||||
|
|
@ -403,14 +428,17 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
warmup_steps=getattr(args, "warmup_steps", 0),
|
||||
optimizer=args.optimizer,
|
||||
device=args.device,
|
||||
save_path=args.save_path,
|
||||
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
|
||||
# GRPO/PPO hyperparameters
|
||||
kl_coef=getattr(args, "kl_coef", 0.1),
|
||||
kl_coef=getattr(args, "kl_coef", 0.0),
|
||||
clip_eps=getattr(args, "clip_eps", 0.2),
|
||||
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
|
||||
adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False),
|
||||
adafactor_relative_step=getattr(args, "adafactor_relative_step", False),
|
||||
# vLLM settings
|
||||
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
|
||||
vllm_port=args.vllm_port,
|
||||
|
|
@ -422,6 +450,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
wandb_project=args.wandb_project,
|
||||
wandb_group=getattr(args, "wandb_group", None),
|
||||
weight_bridge_mode=getattr(args, "weight_bridge_mode", "none"),
|
||||
train_layer_indices=getattr(args, "train_layer_indices", None),
|
||||
trainer_rank=getattr(args, "trainer_rank", 0),
|
||||
world_size=getattr(args, "world_size", 1),
|
||||
init_method=getattr(args, "init_method", "env://"),
|
||||
|
|
|
|||
|
|
@ -32,21 +32,41 @@ class TrainingConfig(BaseModel):
|
|||
gradient_accumulation_steps: int = Field(
|
||||
32, description="Number of gradient accumulation steps"
|
||||
)
|
||||
warmup_steps: int = Field(
|
||||
0,
|
||||
description=(
|
||||
"Number of initial optimizer steps for linear LR warmup. "
|
||||
"0 disables warmup."
|
||||
),
|
||||
)
|
||||
optimizer: Literal["adamw", "adamw_8bit", "adafactor"] = Field(
|
||||
"adamw_8bit",
|
||||
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
|
||||
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)",
|
||||
description="Optimizer to use: 'adamw' (full precision), "
|
||||
"'adamw_8bit' (8-bit states, requires bitsandbytes), "
|
||||
"'adafactor' (Adafactor optimizer)",
|
||||
)
|
||||
adafactor_scale_parameter: bool = Field(
|
||||
False,
|
||||
description=(
|
||||
"Whether to enable Adafactor scale_parameter behavior when using "
|
||||
"optimizer='adafactor'."
|
||||
),
|
||||
)
|
||||
adafactor_relative_step: bool = Field(
|
||||
False,
|
||||
description=(
|
||||
"Whether to enable Adafactor relative_step behavior when using "
|
||||
"optimizer='adafactor'."
|
||||
),
|
||||
)
|
||||
|
||||
# === GRPO/PPO Hyperparameters ===
|
||||
kl_coef: float = Field(
|
||||
0.1,
|
||||
0.0,
|
||||
description=(
|
||||
"KL divergence penalty coefficient (beta). "
|
||||
"Controls how much the policy can deviate from the reference (inference-time) policy. "
|
||||
"Higher values = more conservative updates, prevents reward hacking. "
|
||||
"Set to 0 to disable KL penalty (not recommended)."
|
||||
"Coefficient for sampled-token KL-like regularization against rollout/reference "
|
||||
"logprobs. Higher values make updates more conservative. "
|
||||
"Set to 0 to disable this term."
|
||||
),
|
||||
)
|
||||
clip_eps: float = Field(
|
||||
|
|
@ -121,6 +141,13 @@ class TrainingConfig(BaseModel):
|
|||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||
),
|
||||
)
|
||||
train_layer_indices: Optional[List[int]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional list of transformer layer indices to train in shared/legacy "
|
||||
"full-model modes. If None, all layers are trainable."
|
||||
),
|
||||
)
|
||||
|
||||
# === Distributed Training Configuration ===
|
||||
trainer_rank: int = Field(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ Handles:
|
|||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
|
@ -119,6 +120,7 @@ def load_model_and_tokenizer(
|
|||
|
||||
if model is not None:
|
||||
print("[Setup] ✓ Single-copy mode active - using vLLM's tensors directly!")
|
||||
_apply_train_layer_filter(model, config.train_layer_indices)
|
||||
# Enable gradient checkpointing to save memory (was missing before!)
|
||||
_setup_gradient_checkpointing(model, config)
|
||||
model.train()
|
||||
|
|
@ -141,6 +143,7 @@ def load_model_and_tokenizer(
|
|||
print("[Setup] Loading model for legacy mode...")
|
||||
model = _load_model_with_attention(config.model_name)
|
||||
model.to(config.device)
|
||||
_apply_train_layer_filter(model, config.train_layer_indices)
|
||||
|
||||
# Enable gradient checkpointing
|
||||
_setup_gradient_checkpointing(model, config)
|
||||
|
|
@ -242,6 +245,59 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
|
|||
return model
|
||||
|
||||
|
||||
def _apply_train_layer_filter(
|
||||
model: torch.nn.Module, layer_indices: Optional[list[int]]
|
||||
) -> None:
|
||||
"""
|
||||
Freeze all parameters except selected transformer block indices.
|
||||
|
||||
Applies to full-model modes (shared_vllm / legacy), not LoRA.
|
||||
"""
|
||||
if layer_indices is None:
|
||||
return
|
||||
|
||||
num_hidden_layers = getattr(model.config, "num_hidden_layers", None)
|
||||
if num_hidden_layers is None:
|
||||
num_hidden_layers = getattr(model.config, "n_layer", None)
|
||||
if num_hidden_layers is None:
|
||||
raise RuntimeError(
|
||||
"Model config does not expose num_hidden_layers or n_layer; "
|
||||
"cannot validate --train-layer-indices for this architecture."
|
||||
)
|
||||
|
||||
invalid = [idx for idx in layer_indices if idx >= num_hidden_layers]
|
||||
if invalid:
|
||||
raise ValueError(
|
||||
f"Invalid --train-layer-indices {invalid} for model with "
|
||||
f"{num_hidden_layers} layers (valid range: 0-{num_hidden_layers - 1})"
|
||||
)
|
||||
|
||||
allowed = set(layer_indices)
|
||||
layer_pattern = re.compile(r"\.layers\.(\d+)\.")
|
||||
trainable_params = 0
|
||||
total_params = 0
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
match = layer_pattern.search(name)
|
||||
should_train = bool(match and int(match.group(1)) in allowed)
|
||||
param.requires_grad_(should_train)
|
||||
total_params += param.numel()
|
||||
if should_train:
|
||||
trainable_params += param.numel()
|
||||
|
||||
if trainable_params == 0:
|
||||
raise RuntimeError(
|
||||
"--train-layer-indices did not match any trainable parameters. "
|
||||
"Check architecture naming and selected indices."
|
||||
)
|
||||
|
||||
pct = 100.0 * trainable_params / max(total_params, 1)
|
||||
print(
|
||||
f"[Setup] Training only transformer layers {sorted(allowed)} "
|
||||
f"({trainable_params}/{total_params} params, {pct:.2f}%)"
|
||||
)
|
||||
|
||||
|
||||
def _setup_gradient_checkpointing(
|
||||
model: torch.nn.Module, config: TrainingConfig
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -194,6 +194,7 @@ def main():
|
|||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
warmup_steps=getattr(args, "warmup_steps", 0),
|
||||
optimizer=args.optimizer,
|
||||
device="cuda:0", # Always 0 since we set CUDA_VISIBLE_DEVICES
|
||||
save_path=args.save_path,
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ set -euo pipefail
|
|||
|
||||
# Runs three GSM8K test trainings with separate infra/ports:
|
||||
# 1) shared_vllm
|
||||
# 2) lora_only (+ layer filtering support)
|
||||
# 3) lora_restart (+ layer filtering support)
|
||||
# 2) lora_only
|
||||
# 3) lora_restart
|
||||
#
|
||||
# Usage:
|
||||
# chmod +x example_trainer/run_gsm8k_lora_matrix.sh
|
||||
|
|
@ -12,8 +12,11 @@ set -euo pipefail
|
|||
#
|
||||
# Optional environment overrides:
|
||||
# MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
# TRAINING_STEPS=10
|
||||
# LORA_LAYER_INDICES="20-31"
|
||||
# TRAINING_STEPS=30
|
||||
# WARMUP_STEPS=5
|
||||
# MATRIX_TARGETED=1 # auto-enable layer targeting defaults for smoke tests
|
||||
# SHARED_LAYER_INDICES="0-3" # overrides MATRIX_TARGETED default
|
||||
# LORA_LAYER_INDICES="0-3" # overrides MATRIX_TARGETED default
|
||||
# WANDB_PROJECT="gsm8k-grpo-smoke"
|
||||
# WANDB_GROUP="gsm8k-$(date +%Y%m%d-%H%M%S)"
|
||||
# START_API_PORT=8002
|
||||
|
|
@ -37,11 +40,12 @@ cd "$ROOT_DIR"
|
|||
|
||||
PYTHON_BIN="${PYTHON_BIN:-python3}"
|
||||
MODEL_NAME="${MODEL_NAME:-NousResearch/Hermes-3-Llama-3.1-8B}"
|
||||
TRAINING_STEPS="${TRAINING_STEPS:-10}"
|
||||
TRAINING_STEPS="${TRAINING_STEPS:-30}"
|
||||
BATCH_SIZE="${BATCH_SIZE:-4}"
|
||||
GRAD_ACCUM="${GRAD_ACCUM:-4}"
|
||||
LR="${LR:-1e-5}"
|
||||
KL_COEF="${KL_COEF:-0.1}"
|
||||
WARMUP_STEPS="${WARMUP_STEPS:-5}"
|
||||
KL_COEF="${KL_COEF:-0.0}"
|
||||
CLIP_EPS="${CLIP_EPS:-0.2}"
|
||||
GPU_MEMORY_UTILIZATION="${GPU_MEMORY_UTILIZATION:-0.45}"
|
||||
MAX_MODEL_LEN="${MAX_MODEL_LEN:-4096}"
|
||||
|
|
@ -50,7 +54,13 @@ LORA_R="${LORA_R:-16}"
|
|||
LORA_ALPHA="${LORA_ALPHA:-32}"
|
||||
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
|
||||
LORA_TARGET_MODULES="${LORA_TARGET_MODULES:-q_proj v_proj}"
|
||||
MATRIX_TARGETED="${MATRIX_TARGETED:-1}"
|
||||
SHARED_LAYER_INDICES="${SHARED_LAYER_INDICES:-}"
|
||||
LORA_LAYER_INDICES="${LORA_LAYER_INDICES:-}"
|
||||
if [[ "$MATRIX_TARGETED" == "1" ]]; then
|
||||
SHARED_LAYER_INDICES="${SHARED_LAYER_INDICES:-0-3}"
|
||||
LORA_LAYER_INDICES="${LORA_LAYER_INDICES:-0-3}"
|
||||
fi
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-grpo-smoke}"
|
||||
WANDB_GROUP="${WANDB_GROUP:-gsm8k-$(date +%Y%m%d-%H%M%S)}"
|
||||
START_API_PORT="${START_API_PORT:-8002}"
|
||||
|
|
@ -153,6 +163,12 @@ cleanup_run() {
|
|||
run_ports=()
|
||||
}
|
||||
|
||||
add_shared_layer_flag() {
|
||||
if [[ -n "$SHARED_LAYER_INDICES" ]]; then
|
||||
echo "--train-layer-indices" "$SHARED_LAYER_INDICES"
|
||||
fi
|
||||
}
|
||||
|
||||
add_lora_layer_flag() {
|
||||
if [[ -n "$LORA_LAYER_INDICES" ]]; then
|
||||
echo "--lora-layer-indices" "$LORA_LAYER_INDICES"
|
||||
|
|
@ -165,6 +181,7 @@ common_trainer_flags() {
|
|||
--training-steps "$TRAINING_STEPS" \
|
||||
--batch-size "$BATCH_SIZE" \
|
||||
--gradient-accumulation-steps "$GRAD_ACCUM" \
|
||||
--warmup-steps "$WARMUP_STEPS" \
|
||||
--lr "$LR" \
|
||||
--kl-coef "$KL_COEF" \
|
||||
--clip-eps "$CLIP_EPS" \
|
||||
|
|
@ -252,7 +269,8 @@ run_shared_vllm() {
|
|||
--save-path "$save_dir" \
|
||||
--vllm-port "$vllm_port" \
|
||||
--vllm-config-path "${bridge_dir}/vllm_bridge_config.json" \
|
||||
--atropos-url "http://localhost:${api_port}"
|
||||
--atropos-url "http://localhost:${api_port}" \
|
||||
$(add_shared_layer_flag)
|
||||
printf '\n'
|
||||
log "[DRY RUN] trainer log path: $mode_dir/trainer.log"
|
||||
else
|
||||
|
|
@ -263,7 +281,8 @@ run_shared_vllm() {
|
|||
--save-path "$save_dir" \
|
||||
--vllm-port "$vllm_port" \
|
||||
--vllm-config-path "${bridge_dir}/vllm_bridge_config.json" \
|
||||
--atropos-url "http://localhost:${api_port}" | tee "$mode_dir/trainer.log"
|
||||
--atropos-url "http://localhost:${api_port}" \
|
||||
$(add_shared_layer_flag) | tee "$mode_dir/trainer.log"
|
||||
fi
|
||||
|
||||
cleanup_run
|
||||
|
|
@ -419,6 +438,8 @@ log "Model: $MODEL_NAME"
|
|||
log "W&B project/group: $WANDB_PROJECT / $WANDB_GROUP"
|
||||
log "Dry run mode: $DRY_RUN"
|
||||
log "Output base directory (logs + saves): $OUTPUT_BASE_DIR"
|
||||
log "Warmup steps: $WARMUP_STEPS"
|
||||
log "Targeted-layer matrix profile: $MATRIX_TARGETED"
|
||||
log "Port plan:"
|
||||
log " shared_vllm: run-api=${SHARED_API_PORT}, vllm=${SHARED_VLLM_PORT}"
|
||||
log " lora_only: run-api=${LORA_ONLY_API_PORT}, vllm=${LORA_ONLY_VLLM_PORT}"
|
||||
|
|
@ -427,6 +448,11 @@ log "GPU plan:"
|
|||
log " shared_vllm: trainer+vllm on GPU ${SHARED_GPU} (required for shared weights)"
|
||||
log " lora_only: trainer GPU ${LORA_ONLY_TRAINER_GPU}, vllm GPU ${LORA_ONLY_VLLM_GPU}"
|
||||
log " lora_restart: trainer GPU ${LORA_RESTART_TRAINER_GPU}, vllm GPU ${LORA_RESTART_VLLM_GPU}"
|
||||
if [[ -n "$SHARED_LAYER_INDICES" ]]; then
|
||||
log "Shared-model train layer indices: $SHARED_LAYER_INDICES"
|
||||
else
|
||||
log "Shared-model train layer indices: all layers"
|
||||
fi
|
||||
if [[ -n "$LORA_LAYER_INDICES" ]]; then
|
||||
log "LoRA layer indices: $LORA_LAYER_INDICES"
|
||||
else
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import os
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import requests
|
||||
|
|
@ -21,16 +22,22 @@ from torch.optim import AdamW
|
|||
from .api import check_atropos_api, register_trainer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
|
||||
"""
|
||||
Create optimizer based on config.optimizer setting.
|
||||
|
||||
Options:
|
||||
- 'adamw': Standard AdamW (full precision, ~32GB GPU for 8B model)
|
||||
- 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes)
|
||||
- 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies)
|
||||
- 'adamw': Standard AdamW
|
||||
- 'adamw_8bit': 8-bit AdamW from bitsandbytes (requires bitsandbytes)
|
||||
- 'adafactor': Adafactor optimizer (requires transformers)
|
||||
"""
|
||||
return create_optimizer_for_params(model.parameters(), config)
|
||||
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
||||
if not trainable_params:
|
||||
raise RuntimeError("No trainable parameters found for optimizer creation.")
|
||||
return create_optimizer_for_params(trainable_params, config)
|
||||
|
||||
|
||||
def create_optimizer_for_params(
|
||||
|
|
@ -38,36 +45,46 @@ def create_optimizer_for_params(
|
|||
) -> torch.optim.Optimizer:
|
||||
"""Create optimizer for a specific parameter iterable."""
|
||||
params = list(params)
|
||||
if not params:
|
||||
raise RuntimeError("Optimizer received an empty parameter list.")
|
||||
|
||||
if config.optimizer == "adamw_8bit":
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
optimizer = bnb.optim.AdamW8bit(params, lr=config.lr)
|
||||
print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
|
||||
logger.info("[Setup] Using 8-bit AdamW optimizer")
|
||||
return optimizer
|
||||
except ImportError:
|
||||
print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW")
|
||||
print("[Setup] Install with: pip install bitsandbytes")
|
||||
logger.warning(
|
||||
"[Setup] bitsandbytes not installed, falling back to AdamW"
|
||||
)
|
||||
logger.info("[Setup] Install with: pip install bitsandbytes")
|
||||
|
||||
if config.optimizer == "adafactor":
|
||||
try:
|
||||
from transformers.optimization import Adafactor
|
||||
|
||||
scale_parameter = getattr(config, "adafactor_scale_parameter", False)
|
||||
relative_step = getattr(config, "adafactor_relative_step", False)
|
||||
optimizer = Adafactor(
|
||||
params,
|
||||
lr=config.lr,
|
||||
scale_parameter=False,
|
||||
relative_step=False,
|
||||
scale_parameter=scale_parameter,
|
||||
relative_step=relative_step,
|
||||
)
|
||||
logger.info(
|
||||
"[Setup] Using Adafactor optimizer (scale_parameter=%s, relative_step=%s)",
|
||||
scale_parameter,
|
||||
relative_step,
|
||||
)
|
||||
print("[Setup] Using Adafactor (no momentum, saves ~24GB)")
|
||||
return optimizer
|
||||
except ImportError:
|
||||
print("[Setup] WARNING: transformers Adafactor not available, using AdamW")
|
||||
logger.warning("[Setup] transformers Adafactor unavailable, using AdamW")
|
||||
|
||||
# Default: standard AdamW
|
||||
optimizer = AdamW(params, lr=config.lr)
|
||||
print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)")
|
||||
logger.info("[Setup] Using standard AdamW optimizer")
|
||||
return optimizer
|
||||
|
||||
|
||||
|
|
@ -176,6 +193,7 @@ def train_legacy(config: TrainingConfig):
|
|||
advantage_batches,
|
||||
temperature_batches,
|
||||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
|
|
@ -322,6 +340,7 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
advantage_batches,
|
||||
temperature_batches,
|
||||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
|
|
@ -481,6 +500,7 @@ def train_lora(config: TrainingConfig):
|
|||
advantage_batches,
|
||||
temperature_batches,
|
||||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
|
|
@ -702,6 +722,7 @@ def train_lora_restart(config: TrainingConfig):
|
|||
advantage_batches,
|
||||
temperature_batches,
|
||||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def compute_grpo_loss(
|
|||
temperatures: torch.Tensor,
|
||||
gradient_accumulation_steps: int,
|
||||
inference_logprobs: Optional[torch.Tensor] = None,
|
||||
kl_coef: float = 0.1,
|
||||
kl_coef: float = 0.0,
|
||||
clip_eps: float = 0.2,
|
||||
use_reference_logprobs: bool = True,
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
|
|
@ -79,12 +79,12 @@ def compute_grpo_loss(
|
|||
This implements proper GRPO/PPO with:
|
||||
- Importance sampling ratio: policy(a|s) / policy_old(a|s)
|
||||
- PPO-style clipping to prevent large updates
|
||||
- KL penalty to prevent reward hacking/policy collapse
|
||||
- Optional KL-like regularization term on sampled actions
|
||||
|
||||
The loss encourages the model to:
|
||||
- Increase probability for tokens with positive advantages
|
||||
- Decrease probability for tokens with negative advantages
|
||||
- Stay close to the reference policy (inference-time policy)
|
||||
- Stay close to the rollout/reference policy on sampled tokens
|
||||
|
||||
Args:
|
||||
model: The model to compute loss for
|
||||
|
|
@ -94,7 +94,7 @@ def compute_grpo_loss(
|
|||
temperatures: Temperature values [batch, 1, 1]
|
||||
gradient_accumulation_steps: Number of accumulation steps (for scaling)
|
||||
inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len]
|
||||
kl_coef: KL penalty coefficient (beta). Higher = more conservative updates
|
||||
kl_coef: Coefficient for sampled-token KL-like regularization
|
||||
clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps]
|
||||
use_reference_logprobs: If True, use inference_logprobs as reference policy
|
||||
|
||||
|
|
@ -192,12 +192,13 @@ def compute_grpo_loss(
|
|||
# Average over tokens, then over batch
|
||||
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
|
||||
# KL penalty: encourage staying close to reference policy
|
||||
# Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
|
||||
# This estimator is guaranteed to be non-negative (unlike squared log-ratio).
|
||||
# KL-like sampled-token regularizer: encourages staying close to rollout policy.
|
||||
# This uses Schulman's non-negative estimator on sampled actions:
|
||||
# exp(-log_ratio) + log_ratio - 1
|
||||
# where log_ratio = log pi(a|s) - log pi_ref(a|s).
|
||||
# NOTE: this is not full-distribution KL unless evaluated over the full action space.
|
||||
if kl_coef > 0:
|
||||
# Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1
|
||||
# = exp(-log_ratio) + log_ratio - 1
|
||||
# Schulman's sampled-token estimator.
|
||||
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
|
||||
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
total_loss = (
|
||||
|
|
@ -290,6 +291,7 @@ def run_training_step(
|
|||
advantage_batches: List[torch.Tensor],
|
||||
temperature_batches: List[torch.Tensor],
|
||||
config: TrainingConfig,
|
||||
step_idx: int,
|
||||
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
|
|
@ -308,7 +310,8 @@ def run_training_step(
|
|||
label_batches: List of label tensors
|
||||
advantage_batches: List of advantage tensors
|
||||
temperature_batches: List of temperature tensors
|
||||
config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs)
|
||||
config: Training configuration (includes kl_coef, clip_eps, warmup_steps)
|
||||
step_idx: Current global training step (0-based)
|
||||
inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels
|
||||
|
||||
Returns:
|
||||
|
|
@ -331,10 +334,20 @@ def run_training_step(
|
|||
all_inference_logprobs: List[torch.Tensor] = []
|
||||
|
||||
# Get GRPO hyperparameters from config
|
||||
kl_coef = getattr(config, "kl_coef", 0.1)
|
||||
kl_coef = getattr(config, "kl_coef", 0.0)
|
||||
clip_eps = getattr(config, "clip_eps", 0.2)
|
||||
use_reference_logprobs = getattr(config, "use_reference_logprobs", True)
|
||||
|
||||
# Apply linear warmup to optimizer LR for early-step stability.
|
||||
warmup_steps = max(0, int(getattr(config, "warmup_steps", 0)))
|
||||
if warmup_steps > 0 and step_idx < warmup_steps:
|
||||
warmup_scale = float(step_idx + 1) / float(max(1, warmup_steps))
|
||||
current_lr = float(config.lr) * warmup_scale
|
||||
else:
|
||||
current_lr = float(config.lr)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = current_lr
|
||||
|
||||
# Accumulate gradients over micro-batches
|
||||
num_batches = len(token_batches) if token_batches else 1
|
||||
|
||||
|
|
@ -410,6 +423,7 @@ def run_training_step(
|
|||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"lr": current_lr,
|
||||
"grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
|
||||
"pos_logp": total_pos_logp,
|
||||
"neg_logp": total_neg_logp,
|
||||
|
|
@ -515,6 +529,7 @@ def log_metrics(
|
|||
log_dict = {
|
||||
"train/loss": metrics["loss"],
|
||||
"train/grad_norm": metrics["grad_norm"],
|
||||
"train/lr": metrics.get("lr", 0.0),
|
||||
"train/pos_logp": metrics.get("pos_logp", 0),
|
||||
"train/neg_logp": metrics.get("neg_logp", 0),
|
||||
# GRPO-specific metrics
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue