feedback fixes: shared layers + hard-coded values + warmup steps

This commit is contained in:
Jai Suphavadeeprasit 2026-02-24 10:28:44 -05:00
parent e1f9b926bb
commit 624b3cdabe
9 changed files with 247 additions and 58 deletions

View file

@ -151,9 +151,10 @@ python -m example_trainer.grpo \
--atropos-url "http://localhost:8002" \
--batch-size 4 \
--gradient-accumulation-steps 4 \
--warmup-steps 20 \
--lr 1e-5 \
--training-steps 30 \
--kl-coef 0.1 \
--kl-coef 0.0 \
--clip-eps 0.2 \
--vllm-restart-interval 5 \
--save-path ./lora_checkpoints \
@ -258,7 +259,8 @@ python -m example_trainer.grpo \
--vllm-port 9001 \
--vllm-config-path /tmp/grpo_training/vllm_bridge_config.json \
--atropos-url "http://localhost:8002" \
--kl-coef 0.1 \
--warmup-steps 20 \
--kl-coef 0.0 \
--clip-eps 0.2
```
@ -307,7 +309,7 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
**IMPORTANT:** These hyperparameters control update stability (clipping guards against reward hacking; the KL term is optional and disabled by default):
```bash
--kl-coef 0.1 # Prevents policy from drifting too far from reference
--kl-coef 0.0 # Default (disable KL penalty)
--clip-eps 0.2 # Limits importance sampling ratio to [0.8, 1.2]
```
@ -328,6 +330,16 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
- `mean_ratio` diverges far from 1.0
- `mean_kl` explodes (> 1.0)
### 3. Use LR Warmup for Stability
Use a short linear warmup when starting fresh training runs or using small-batch settings:
```bash
--warmup-steps 20
```
This linearly ramps learning rate from 0 to `--lr` over the first N optimizer steps.
**Healthy training metrics:**
- `mean_ratio`: 0.8 - 1.2 (close to 1.0)
- `mean_kl`: 0.01 - 0.1
@ -355,9 +367,9 @@ The trainer supports multiple optimizer options to trade off between speed, memo
| Optimizer | GPU Memory for States | Speed | Precision | Dependencies |
|-----------|----------------------|-------|-----------|--------------|
| `adamw` | ~32GB (for 8B model) | Fastest | Full FP32 | None |
| `adamw_8bit` (default) | ~8GB | Fast | 8-bit quantized | `bitsandbytes` |
| `adafactor` | ~8GB | Fast | Full (no momentum) | `transformers` |
| `adamw` | Highest | Fastest | Full FP32 | None |
| `adamw_8bit` (default) | Lower | Fast | 8-bit quantized | `bitsandbytes` |
| `adafactor` | Lower | Fast | Full (no momentum) | `transformers` |
**Usage:**
```bash
@ -571,13 +583,15 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) |
| `--batch-size` | 2 | Micro-batch size |
| `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum |
| `--warmup-steps` | 0 | Linear LR warmup steps (0 disables warmup) |
| `--seq-len` | 2048 | Maximum sequence length |
| `--train-layer-indices` | None | Optional full-model layer filter for shared/legacy modes (examples: `20-31`, `0-3,28-31`) |
### GRPO Hyperparameters
| Argument | Default | Description |
|----------|---------|-------------|
| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) |
| `--kl-coef` | 0.0 | KL penalty strength (higher = more conservative) |
| `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] |
| `--lr` | 1e-5 | Learning rate (NOT --learning-rate) |
| `--no-reference-logprobs` | False | Disable GRPO reference logprobs (falls back to REINFORCE-style updates) |
@ -592,9 +606,9 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `--lora-target-modules` | None | Module names to apply LoRA (`None` falls back to `q_proj v_proj`) |
| `--lora-layer-indices` | None | Optional layer filter (examples: `20-31`, `0-3,28-31`) |
### LoRA Layer Index Guide (by Architecture)
### Layer Index Guide (by Architecture)
`--lora-layer-indices` is model-dependent. Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
Layer-index arguments are model-dependent (`--train-layer-indices` for full/shared modes, `--lora-layer-indices` for LoRA modes). Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
| Architecture family | Common config fields | Typical layer list path | Notes |
|---------------------|----------------------|-------------------------|-------|
@ -628,10 +642,10 @@ PY
If your model has `N` layers:
- Full layers: omit `--lora-layer-indices`
- Top 25%: `--lora-layer-indices {int(0.75*N)}-{N-1}`
- Top 50%: `--lora-layer-indices {int(0.5*N)}-{N-1}`
- Last 12 layers: `--lora-layer-indices {N-12}-{N-1}` (if `N >= 12`)
- Full layers: omit `--train-layer-indices`
- Top 25%: `--train-layer-indices {int(0.75*N)}-{N-1}`
- Top 50%: `--train-layer-indices {int(0.5*N)}-{N-1}`
- Last 12 layers: `--train-layer-indices {N-12}-{N-1}` (if `N >= 12`)
### vLLM Arguments

View file

@ -102,7 +102,7 @@ def save_checkpoint(
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
model.config.save_pretrained(checkpoint_path)
# CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
# CRITICAL: Clean up the copied state_dict to free significant GPU memory.
del state_dict
import gc

View file

@ -17,7 +17,7 @@ from .config import TrainingConfig
# =============================================================================
def _parse_lora_layer_indices(value: str) -> Optional[List[int]]:
def _parse_layer_indices(value: str) -> Optional[List[int]]:
"""
Parse layer indices from comma/range syntax (shared by the LoRA and full-model layer filters).
@ -110,6 +110,12 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
default=32,
help="Number of gradient accumulation steps",
)
group.add_argument(
"--warmup-steps",
type=int,
default=0,
help="Linear LR warmup steps (0 disables warmup).",
)
group.add_argument(
"--optimizer",
type=str,
@ -118,6 +124,16 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
"'adafactor' (no momentum)",
)
group.add_argument(
"--adafactor-scale-parameter",
action="store_true",
help="Enable Adafactor scale_parameter behavior (only used when --optimizer adafactor).",
)
group.add_argument(
"--adafactor-relative-step",
action="store_true",
help="Enable Adafactor relative_step behavior (only used when --optimizer adafactor).",
)
group.add_argument(
"--device",
type=str,
@ -144,8 +160,8 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
group.add_argument(
"--kl-coef",
type=float,
default=0.1,
help="KL divergence penalty coefficient (beta). Higher = more conservative.",
default=0.0,
help="Sampled-token KL-like regularization coefficient. Higher = more conservative.",
)
group.add_argument(
"--clip-eps",
@ -256,6 +272,15 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None:
default=None,
help="Explicit path to vllm_bridge_config.json (auto-detected if not provided)",
)
group.add_argument(
"--train-layer-indices",
type=_parse_layer_indices,
default=None,
help=(
"Optional transformer layer indices to train in full/shared modes, e.g. "
"'20-31' or '0-3,28-31'. If omitted, all layers are trainable."
),
)
def add_lora_args(parser: argparse.ArgumentParser) -> None:
@ -275,7 +300,7 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
)
group.add_argument(
"--lora-layer-indices",
type=_parse_lora_layer_indices,
type=_parse_layer_indices,
default=None,
help=(
"Optional layer indices to apply LoRA to, e.g. '20-31' or "
@ -403,14 +428,17 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
batch_size=args.batch_size,
seq_len=args.seq_len,
gradient_accumulation_steps=args.gradient_accumulation_steps,
warmup_steps=getattr(args, "warmup_steps", 0),
optimizer=args.optimizer,
device=args.device,
save_path=args.save_path,
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
# GRPO/PPO hyperparameters
kl_coef=getattr(args, "kl_coef", 0.1),
kl_coef=getattr(args, "kl_coef", 0.0),
clip_eps=getattr(args, "clip_eps", 0.2),
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False),
adafactor_relative_step=getattr(args, "adafactor_relative_step", False),
# vLLM settings
vllm_restart_interval=getattr(args, "vllm_restart_interval", 3),
vllm_port=args.vllm_port,
@ -422,6 +450,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
wandb_project=args.wandb_project,
wandb_group=getattr(args, "wandb_group", None),
weight_bridge_mode=getattr(args, "weight_bridge_mode", "none"),
train_layer_indices=getattr(args, "train_layer_indices", None),
trainer_rank=getattr(args, "trainer_rank", 0),
world_size=getattr(args, "world_size", 1),
init_method=getattr(args, "init_method", "env://"),

View file

@ -32,21 +32,41 @@ class TrainingConfig(BaseModel):
gradient_accumulation_steps: int = Field(
32, description="Number of gradient accumulation steps"
)
warmup_steps: int = Field(
0,
description=(
"Number of initial optimizer steps for linear LR warmup. "
"0 disables warmup."
),
)
optimizer: Literal["adamw", "adamw_8bit", "adafactor"] = Field(
"adamw_8bit",
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
"'adafactor' (no momentum, ~8GB GPU)",
description="Optimizer to use: 'adamw' (full precision), "
"'adamw_8bit' (8-bit states, requires bitsandbytes), "
"'adafactor' (Adafactor optimizer)",
)
adafactor_scale_parameter: bool = Field(
False,
description=(
"Whether to enable Adafactor scale_parameter behavior when using "
"optimizer='adafactor'."
),
)
adafactor_relative_step: bool = Field(
False,
description=(
"Whether to enable Adafactor relative_step behavior when using "
"optimizer='adafactor'."
),
)
# === GRPO/PPO Hyperparameters ===
kl_coef: float = Field(
0.1,
0.0,
description=(
"KL divergence penalty coefficient (beta). "
"Controls how much the policy can deviate from the reference (inference-time) policy. "
"Higher values = more conservative updates, prevents reward hacking. "
"Set to 0 to disable KL penalty (not recommended)."
"Coefficient for sampled-token KL-like regularization against rollout/reference "
"logprobs. Higher values make updates more conservative. "
"Set to 0 to disable this term."
),
)
clip_eps: float = Field(
@ -121,6 +141,13 @@ class TrainingConfig(BaseModel):
"'none': legacy mode, restart vLLM with new checkpoint files."
),
)
train_layer_indices: Optional[List[int]] = Field(
None,
description=(
"Optional list of transformer layer indices to train in shared/legacy "
"full-model modes. If None, all layers are trainable."
),
)
# === Distributed Training Configuration ===
trainer_rank: int = Field(

View file

@ -10,6 +10,7 @@ Handles:
import base64
import json
import os
import re
from typing import Dict, Optional, Tuple
import torch
@ -119,6 +120,7 @@ def load_model_and_tokenizer(
if model is not None:
print("[Setup] ✓ Single-copy mode active - using vLLM's tensors directly!")
_apply_train_layer_filter(model, config.train_layer_indices)
# Enable gradient checkpointing to save memory (was missing before!)
_setup_gradient_checkpointing(model, config)
model.train()
@ -141,6 +143,7 @@ def load_model_and_tokenizer(
print("[Setup] Loading model for legacy mode...")
model = _load_model_with_attention(config.model_name)
model.to(config.device)
_apply_train_layer_filter(model, config.train_layer_indices)
# Enable gradient checkpointing
_setup_gradient_checkpointing(model, config)
@ -242,6 +245,59 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
return model
def _apply_train_layer_filter(
model: torch.nn.Module, layer_indices: Optional[list[int]]
) -> None:
"""
Freeze all parameters except selected transformer block indices.
Applies to full-model modes (shared_vllm / legacy), not LoRA.
"""
if layer_indices is None:
return
num_hidden_layers = getattr(model.config, "num_hidden_layers", None)
if num_hidden_layers is None:
num_hidden_layers = getattr(model.config, "n_layer", None)
if num_hidden_layers is None:
raise RuntimeError(
"Model config does not expose num_hidden_layers or n_layer; "
"cannot validate --train-layer-indices for this architecture."
)
invalid = [idx for idx in layer_indices if idx >= num_hidden_layers]
if invalid:
raise ValueError(
f"Invalid --train-layer-indices {invalid} for model with "
f"{num_hidden_layers} layers (valid range: 0-{num_hidden_layers - 1})"
)
allowed = set(layer_indices)
layer_pattern = re.compile(r"\.layers\.(\d+)\.")
trainable_params = 0
total_params = 0
for name, param in model.named_parameters():
match = layer_pattern.search(name)
should_train = bool(match and int(match.group(1)) in allowed)
param.requires_grad_(should_train)
total_params += param.numel()
if should_train:
trainable_params += param.numel()
if trainable_params == 0:
raise RuntimeError(
"--train-layer-indices did not match any trainable parameters. "
"Check architecture naming and selected indices."
)
pct = 100.0 * trainable_params / max(total_params, 1)
print(
f"[Setup] Training only transformer layers {sorted(allowed)} "
f"({trainable_params}/{total_params} params, {pct:.2f}%)"
)
def _setup_gradient_checkpointing(
model: torch.nn.Module, config: TrainingConfig
) -> None:

View file

@ -194,6 +194,7 @@ def main():
batch_size=args.batch_size,
seq_len=args.seq_len,
gradient_accumulation_steps=args.gradient_accumulation_steps,
warmup_steps=getattr(args, "warmup_steps", 0),
optimizer=args.optimizer,
device="cuda:0", # Always 0 since we set CUDA_VISIBLE_DEVICES
save_path=args.save_path,

View file

@ -3,8 +3,8 @@ set -euo pipefail
# Runs three GSM8K test trainings with separate infra/ports:
# 1) shared_vllm
# 2) lora_only (+ layer filtering support)
# 3) lora_restart (+ layer filtering support)
# 2) lora_only
# 3) lora_restart
#
# Usage:
# chmod +x example_trainer/run_gsm8k_lora_matrix.sh
@ -12,8 +12,11 @@ set -euo pipefail
#
# Optional environment overrides:
# MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
# TRAINING_STEPS=10
# LORA_LAYER_INDICES="20-31"
# TRAINING_STEPS=30
# WARMUP_STEPS=5
# MATRIX_TARGETED=1 # auto-enable layer targeting defaults for smoke tests
# SHARED_LAYER_INDICES="0-3" # overrides MATRIX_TARGETED default
# LORA_LAYER_INDICES="0-3" # overrides MATRIX_TARGETED default
# WANDB_PROJECT="gsm8k-grpo-smoke"
# WANDB_GROUP="gsm8k-$(date +%Y%m%d-%H%M%S)"
# START_API_PORT=8002
@ -37,11 +40,12 @@ cd "$ROOT_DIR"
PYTHON_BIN="${PYTHON_BIN:-python3}"
MODEL_NAME="${MODEL_NAME:-NousResearch/Hermes-3-Llama-3.1-8B}"
TRAINING_STEPS="${TRAINING_STEPS:-10}"
TRAINING_STEPS="${TRAINING_STEPS:-30}"
BATCH_SIZE="${BATCH_SIZE:-4}"
GRAD_ACCUM="${GRAD_ACCUM:-4}"
LR="${LR:-1e-5}"
KL_COEF="${KL_COEF:-0.1}"
WARMUP_STEPS="${WARMUP_STEPS:-5}"
KL_COEF="${KL_COEF:-0.0}"
CLIP_EPS="${CLIP_EPS:-0.2}"
GPU_MEMORY_UTILIZATION="${GPU_MEMORY_UTILIZATION:-0.45}"
MAX_MODEL_LEN="${MAX_MODEL_LEN:-4096}"
@ -50,7 +54,13 @@ LORA_R="${LORA_R:-16}"
LORA_ALPHA="${LORA_ALPHA:-32}"
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
LORA_TARGET_MODULES="${LORA_TARGET_MODULES:-q_proj v_proj}"
MATRIX_TARGETED="${MATRIX_TARGETED:-1}"
SHARED_LAYER_INDICES="${SHARED_LAYER_INDICES:-}"
LORA_LAYER_INDICES="${LORA_LAYER_INDICES:-}"
if [[ "$MATRIX_TARGETED" == "1" ]]; then
SHARED_LAYER_INDICES="${SHARED_LAYER_INDICES:-0-3}"
LORA_LAYER_INDICES="${LORA_LAYER_INDICES:-0-3}"
fi
WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-grpo-smoke}"
WANDB_GROUP="${WANDB_GROUP:-gsm8k-$(date +%Y%m%d-%H%M%S)}"
START_API_PORT="${START_API_PORT:-8002}"
@ -153,6 +163,12 @@ cleanup_run() {
run_ports=()
}
# Emit "--train-layer-indices <indices>" for full/shared-model runs when
# SHARED_LAYER_INDICES is non-empty; emit nothing otherwise. Intended to be
# spliced into a trainer command line via $(add_shared_layer_flag).
add_shared_layer_flag() {
  if [[ -n "$SHARED_LAYER_INDICES" ]]; then
    echo "--train-layer-indices" "$SHARED_LAYER_INDICES"
  fi
}
add_lora_layer_flag() {
if [[ -n "$LORA_LAYER_INDICES" ]]; then
echo "--lora-layer-indices" "$LORA_LAYER_INDICES"
@ -165,6 +181,7 @@ common_trainer_flags() {
--training-steps "$TRAINING_STEPS" \
--batch-size "$BATCH_SIZE" \
--gradient-accumulation-steps "$GRAD_ACCUM" \
--warmup-steps "$WARMUP_STEPS" \
--lr "$LR" \
--kl-coef "$KL_COEF" \
--clip-eps "$CLIP_EPS" \
@ -252,7 +269,8 @@ run_shared_vllm() {
--save-path "$save_dir" \
--vllm-port "$vllm_port" \
--vllm-config-path "${bridge_dir}/vllm_bridge_config.json" \
--atropos-url "http://localhost:${api_port}"
--atropos-url "http://localhost:${api_port}" \
$(add_shared_layer_flag)
printf '\n'
log "[DRY RUN] trainer log path: $mode_dir/trainer.log"
else
@ -263,7 +281,8 @@ run_shared_vllm() {
--save-path "$save_dir" \
--vllm-port "$vllm_port" \
--vllm-config-path "${bridge_dir}/vllm_bridge_config.json" \
--atropos-url "http://localhost:${api_port}" | tee "$mode_dir/trainer.log"
--atropos-url "http://localhost:${api_port}" \
$(add_shared_layer_flag) | tee "$mode_dir/trainer.log"
fi
cleanup_run
@ -419,6 +438,8 @@ log "Model: $MODEL_NAME"
log "W&B project/group: $WANDB_PROJECT / $WANDB_GROUP"
log "Dry run mode: $DRY_RUN"
log "Output base directory (logs + saves): $OUTPUT_BASE_DIR"
log "Warmup steps: $WARMUP_STEPS"
log "Targeted-layer matrix profile: $MATRIX_TARGETED"
log "Port plan:"
log " shared_vllm: run-api=${SHARED_API_PORT}, vllm=${SHARED_VLLM_PORT}"
log " lora_only: run-api=${LORA_ONLY_API_PORT}, vllm=${LORA_ONLY_VLLM_PORT}"
@ -427,6 +448,11 @@ log "GPU plan:"
log " shared_vllm: trainer+vllm on GPU ${SHARED_GPU} (required for shared weights)"
log " lora_only: trainer GPU ${LORA_ONLY_TRAINER_GPU}, vllm GPU ${LORA_ONLY_VLLM_GPU}"
log " lora_restart: trainer GPU ${LORA_RESTART_TRAINER_GPU}, vllm GPU ${LORA_RESTART_VLLM_GPU}"
if [[ -n "$SHARED_LAYER_INDICES" ]]; then
log "Shared-model train layer indices: $SHARED_LAYER_INDICES"
else
log "Shared-model train layer indices: all layers"
fi
if [[ -n "$LORA_LAYER_INDICES" ]]; then
log "LoRA layer indices: $LORA_LAYER_INDICES"
else

View file

@ -12,6 +12,7 @@ import os
import subprocess
import sys
import time
import logging
from typing import Iterable, Optional
import requests
@ -21,16 +22,22 @@ from torch.optim import AdamW
from .api import check_atropos_api, register_trainer
logger = logging.getLogger(__name__)
def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
    """
    Create an optimizer over the model's *trainable* parameters only.

    Options (``config.optimizer``):
    - 'adamw': Standard AdamW
    - 'adamw_8bit': 8-bit AdamW from bitsandbytes (requires bitsandbytes)
    - 'adafactor': Adafactor optimizer (requires transformers)

    Args:
        model: Model whose parameters with ``requires_grad=True`` are optimized.
        config: Training configuration providing ``optimizer`` and ``lr``.

    Raises:
        RuntimeError: If no parameter has ``requires_grad=True`` (e.g. a layer
            filter froze everything).
    """
    # Restrict to trainable params so frozen layers consume no optimizer state.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        raise RuntimeError("No trainable parameters found for optimizer creation.")
    return create_optimizer_for_params(trainable_params, config)
def create_optimizer_for_params(
@ -38,36 +45,46 @@ def create_optimizer_for_params(
) -> torch.optim.Optimizer:
"""Create optimizer for a specific parameter iterable."""
params = list(params)
if not params:
raise RuntimeError("Optimizer received an empty parameter list.")
if config.optimizer == "adamw_8bit":
try:
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(params, lr=config.lr)
print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
logger.info("[Setup] Using 8-bit AdamW optimizer")
return optimizer
except ImportError:
print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW")
print("[Setup] Install with: pip install bitsandbytes")
logger.warning(
"[Setup] bitsandbytes not installed, falling back to AdamW"
)
logger.info("[Setup] Install with: pip install bitsandbytes")
if config.optimizer == "adafactor":
try:
from transformers.optimization import Adafactor
scale_parameter = getattr(config, "adafactor_scale_parameter", False)
relative_step = getattr(config, "adafactor_relative_step", False)
optimizer = Adafactor(
params,
lr=config.lr,
scale_parameter=False,
relative_step=False,
scale_parameter=scale_parameter,
relative_step=relative_step,
)
logger.info(
"[Setup] Using Adafactor optimizer (scale_parameter=%s, relative_step=%s)",
scale_parameter,
relative_step,
)
print("[Setup] Using Adafactor (no momentum, saves ~24GB)")
return optimizer
except ImportError:
print("[Setup] WARNING: transformers Adafactor not available, using AdamW")
logger.warning("[Setup] transformers Adafactor unavailable, using AdamW")
# Default: standard AdamW
optimizer = AdamW(params, lr=config.lr)
print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)")
logger.info("[Setup] Using standard AdamW optimizer")
return optimizer
@ -176,6 +193,7 @@ def train_legacy(config: TrainingConfig):
advantage_batches,
temperature_batches,
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
)
step_time = time.time() - step_start
@ -322,6 +340,7 @@ def train_shared_vllm(config: TrainingConfig):
advantage_batches,
temperature_batches,
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
)
step_time = time.time() - step_start
@ -481,6 +500,7 @@ def train_lora(config: TrainingConfig):
advantage_batches,
temperature_batches,
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
)
step_time = time.time() - step_start
@ -702,6 +722,7 @@ def train_lora_restart(config: TrainingConfig):
advantage_batches,
temperature_batches,
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
)
step_time = time.time() - step_start

View file

@ -69,7 +69,7 @@ def compute_grpo_loss(
temperatures: torch.Tensor,
gradient_accumulation_steps: int,
inference_logprobs: Optional[torch.Tensor] = None,
kl_coef: float = 0.1,
kl_coef: float = 0.0,
clip_eps: float = 0.2,
use_reference_logprobs: bool = True,
) -> Tuple[torch.Tensor, dict]:
@ -79,12 +79,12 @@ def compute_grpo_loss(
This implements proper GRPO/PPO with:
- Importance sampling ratio: policy(a|s) / policy_old(a|s)
- PPO-style clipping to prevent large updates
- KL penalty to prevent reward hacking/policy collapse
- Optional KL-like regularization term on sampled actions
The loss encourages the model to:
- Increase probability for tokens with positive advantages
- Decrease probability for tokens with negative advantages
- Stay close to the reference policy (inference-time policy)
- Stay close to the rollout/reference policy on sampled tokens
Args:
model: The model to compute loss for
@ -94,7 +94,7 @@ def compute_grpo_loss(
temperatures: Temperature values [batch, 1, 1]
gradient_accumulation_steps: Number of accumulation steps (for scaling)
inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len]
kl_coef: KL penalty coefficient (beta). Higher = more conservative updates
kl_coef: Coefficient for sampled-token KL-like regularization
clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps]
use_reference_logprobs: If True, use inference_logprobs as reference policy
@ -192,12 +192,13 @@ def compute_grpo_loss(
# Average over tokens, then over batch
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
# KL penalty: encourage staying close to reference policy
# Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
# This estimator is guaranteed to be non-negative (unlike squared log-ratio).
# KL-like sampled-token regularizer: encourages staying close to rollout policy.
# This uses Schulman's non-negative estimator on sampled actions:
# exp(-log_ratio) + log_ratio - 1
# where log_ratio = log pi(a|s) - log pi_ref(a|s).
# NOTE: this is not full-distribution KL unless evaluated over the full action space.
if kl_coef > 0:
# Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1
# = exp(-log_ratio) + log_ratio - 1
# Schulman's sampled-token estimator.
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
total_loss = (
@ -290,6 +291,7 @@ def run_training_step(
advantage_batches: List[torch.Tensor],
temperature_batches: List[torch.Tensor],
config: TrainingConfig,
step_idx: int,
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
) -> dict:
"""
@ -308,7 +310,8 @@ def run_training_step(
label_batches: List of label tensors
advantage_batches: List of advantage tensors
temperature_batches: List of temperature tensors
config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs)
config: Training configuration (includes kl_coef, clip_eps, warmup_steps)
step_idx: Current global training step (0-based)
inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels
Returns:
@ -331,10 +334,20 @@ def run_training_step(
all_inference_logprobs: List[torch.Tensor] = []
# Get GRPO hyperparameters from config
kl_coef = getattr(config, "kl_coef", 0.1)
kl_coef = getattr(config, "kl_coef", 0.0)
clip_eps = getattr(config, "clip_eps", 0.2)
use_reference_logprobs = getattr(config, "use_reference_logprobs", True)
# Apply linear warmup to optimizer LR for early-step stability.
warmup_steps = max(0, int(getattr(config, "warmup_steps", 0)))
if warmup_steps > 0 and step_idx < warmup_steps:
warmup_scale = float(step_idx + 1) / float(max(1, warmup_steps))
current_lr = float(config.lr) * warmup_scale
else:
current_lr = float(config.lr)
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
# Accumulate gradients over micro-batches
num_batches = len(token_batches) if token_batches else 1
@ -410,6 +423,7 @@ def run_training_step(
result = {
"loss": total_loss,
"lr": current_lr,
"grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
"pos_logp": total_pos_logp,
"neg_logp": total_neg_logp,
@ -515,6 +529,7 @@ def log_metrics(
log_dict = {
"train/loss": metrics["loss"],
"train/grad_norm": metrics["grad_norm"],
"train/lr": metrics.get("lr", 0.0),
"train/pos_logp": metrics.get("pos_logp", 0),
"train/neg_logp": metrics.get("neg_logp", 0),
# GRPO-specific metrics