diff --git a/example_trainer/README.md b/example_trainer/README.md index 9be1cf66..52f10ecf 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -151,9 +151,10 @@ python -m example_trainer.grpo \ --atropos-url "http://localhost:8002" \ --batch-size 4 \ --gradient-accumulation-steps 4 \ + --warmup-steps 20 \ --lr 1e-5 \ --training-steps 30 \ - --kl-coef 0.1 \ + --kl-coef 0.0 \ --clip-eps 0.2 \ --vllm-restart-interval 5 \ --save-path ./lora_checkpoints \ @@ -258,7 +259,8 @@ python -m example_trainer.grpo \ --vllm-port 9001 \ --vllm-config-path /tmp/grpo_training/vllm_bridge_config.json \ --atropos-url "http://localhost:8002" \ - --kl-coef 0.1 \ + --warmup-steps 20 \ + --kl-coef 0.0 \ --clip-eps 0.2 ``` @@ -307,7 +309,7 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level **CRITICAL:** Without these hyperparameters, training WILL collapse (reward hacking): ```bash ---kl-coef 0.1 # Prevents policy from drifting too far from reference +--kl-coef 0.0 # Default: KL penalty disabled; raise it (e.g. 0.1) if mean_kl or mean_ratio starts to drift --clip-eps 0.2 # Limits importance sampling ratio to [0.8, 1.2] ``` @@ -328,6 +330,16 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level - `mean_ratio` diverges far from 1.0 - `mean_kl` explodes (> 1.0) +### 3. Use LR Warmup for Stability + +Use a short linear warmup when training from fresh runs or small batch settings: + +```bash +--warmup-steps 20 +``` + +This linearly ramps the learning rate up to `--lr` over the first N training steps: step `k` (1-based) uses `lr * k / N`, so the first step runs at `lr / N`, not 0. 
+ **Healthy training metrics:** - `mean_ratio`: 0.8 - 1.2 (close to 1.0) - `mean_kl`: 0.01 - 0.1 @@ -355,9 +367,9 @@ The trainer supports multiple optimizer options to trade off between speed, memo | Optimizer | GPU Memory for States | Speed | Precision | Dependencies | |-----------|----------------------|-------|-----------|--------------| -| `adamw` | ~32GB (for 8B model) | Fastest | Full FP32 | None | -| `adamw_8bit` (default) | ~8GB | Fast | 8-bit quantized | `bitsandbytes` | -| `adafactor` | ~8GB | Fast | Full (no momentum) | `transformers` | +| `adamw` | Highest | Fastest | Full FP32 | None | +| `adamw_8bit` (default) | Lower | Fast | 8-bit quantized | `bitsandbytes` | +| `adafactor` | Lower | Fast | Full (no momentum) | `transformers` | **Usage:** ```bash @@ -571,13 +583,15 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands | `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) | | `--batch-size` | 2 | Micro-batch size | | `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum | +| `--warmup-steps` | 0 | Linear LR warmup steps (0 disables warmup) | | `--seq-len` | 2048 | Maximum sequence length | +| `--train-layer-indices` | None | Optional full-model layer filter for shared/legacy modes (examples: `20-31`, `0-3,28-31`) | ### GRPO Hyperparameters | Argument | Default | Description | |----------|---------|-------------| -| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) | +| `--kl-coef` | 0.0 | KL penalty strength (higher = more conservative) | | `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] | | `--lr` | 1e-5 | Learning rate (NOT --learning-rate) | | `--no-reference-logprobs` | False | Disable GRPO reference logprobs (falls back to REINFORCE-style updates) | @@ -592,9 +606,9 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands | `--lora-target-modules` | None | Module names to apply LoRA (`None` falls back to `q_proj v_proj`) | | `--lora-layer-indices` | 
None | Optional layer filter (examples: `20-31`, `0-3,28-31`) | -### LoRA Layer Index Guide (by Architecture) +### Layer Index Guide (by Architecture) -`--lora-layer-indices` is model-dependent. Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another. +Layer-index arguments are model-dependent (`--train-layer-indices` for full/shared modes, `--lora-layer-indices` for LoRA modes). Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another. | Architecture family | Common config fields | Typical layer list path | Notes | |---------------------|----------------------|-------------------------|-------| @@ -628,10 +642,10 @@ PY If your model has `N` layers: -- Full layers: omit `--lora-layer-indices` -- Top 25%: `--lora-layer-indices {int(0.75*N)}-{N-1}` -- Top 50%: `--lora-layer-indices {int(0.5*N)}-{N-1}` -- Last 12 layers: `--lora-layer-indices {N-12}-{N-1}` (if `N >= 12`) +- Full layers: omit `--train-layer-indices` +- Top 25%: `--train-layer-indices {int(0.75*N)}-{N-1}` +- Top 50%: `--train-layer-indices {int(0.5*N)}-{N-1}` +- Last 12 layers: `--train-layer-indices {N-12}-{N-1}` (if `N >= 12`) ### vLLM Arguments diff --git a/example_trainer/checkpointing.py b/example_trainer/checkpointing.py index b5d60bbe..0d917d71 100644 --- a/example_trainer/checkpointing.py +++ b/example_trainer/checkpointing.py @@ -102,7 +102,7 @@ def save_checkpoint( torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin")) model.config.save_pretrained(checkpoint_path) - # CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory! + # CRITICAL: Clean up the copied state_dict to free significant GPU memory. 
del state_dict import gc diff --git a/example_trainer/cli.py b/example_trainer/cli.py index 4b602277..cb5e364b 100644 --- a/example_trainer/cli.py +++ b/example_trainer/cli.py @@ -17,7 +17,7 @@ from .config import TrainingConfig # ============================================================================= -def _parse_lora_layer_indices(value: str) -> Optional[List[int]]: +def _parse_layer_indices(value: str) -> Optional[List[int]]: """ Parse LoRA layer indices from comma/range syntax. @@ -110,6 +110,12 @@ def add_training_args(parser: argparse.ArgumentParser) -> None: default=32, help="Number of gradient accumulation steps", ) + group.add_argument( + "--warmup-steps", + type=int, + default=0, + help="Linear LR warmup steps (0 disables warmup).", + ) group.add_argument( "--optimizer", type=str, @@ -118,6 +124,16 @@ def add_training_args(parser: argparse.ArgumentParser) -> None: help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), " "'adafactor' (no momentum)", ) + group.add_argument( + "--adafactor-scale-parameter", + action="store_true", + help="Enable Adafactor scale_parameter behavior (only used when --optimizer adafactor).", + ) + group.add_argument( + "--adafactor-relative-step", + action="store_true", + help="Enable Adafactor relative_step behavior (only used when --optimizer adafactor).", + ) group.add_argument( "--device", type=str, @@ -144,8 +160,8 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None: group.add_argument( "--kl-coef", type=float, - default=0.1, - help="KL divergence penalty coefficient (beta). Higher = more conservative.", + default=0.0, + help="Sampled-token KL-like regularization coefficient. 
Higher = more conservative.", ) group.add_argument( "--clip-eps", @@ -256,6 +272,15 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None: default=None, help="Explicit path to vllm_bridge_config.json (auto-detected if not provided)", ) + group.add_argument( + "--train-layer-indices", + type=_parse_layer_indices, + default=None, + help=( + "Optional transformer layer indices to train in full/shared modes, e.g. " + "'20-31' or '0-3,28-31'. If omitted, all layers are trainable." + ), + ) def add_lora_args(parser: argparse.ArgumentParser) -> None: @@ -275,7 +300,7 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None: ) group.add_argument( "--lora-layer-indices", - type=_parse_lora_layer_indices, + type=_parse_layer_indices, default=None, help=( "Optional layer indices to apply LoRA to, e.g. '20-31' or " @@ -403,14 +428,17 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: batch_size=args.batch_size, seq_len=args.seq_len, gradient_accumulation_steps=args.gradient_accumulation_steps, + warmup_steps=getattr(args, "warmup_steps", 0), optimizer=args.optimizer, device=args.device, save_path=args.save_path, checkpoint_interval=getattr(args, "checkpoint_interval", 3), # GRPO/PPO hyperparameters - kl_coef=getattr(args, "kl_coef", 0.1), + kl_coef=getattr(args, "kl_coef", 0.0), clip_eps=getattr(args, "clip_eps", 0.2), use_reference_logprobs=not getattr(args, "no_reference_logprobs", False), + adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False), + adafactor_relative_step=getattr(args, "adafactor_relative_step", False), # vLLM settings vllm_restart_interval=getattr(args, "vllm_restart_interval", 3), vllm_port=args.vllm_port, @@ -422,6 +450,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: wandb_project=args.wandb_project, wandb_group=getattr(args, "wandb_group", None), weight_bridge_mode=getattr(args, "weight_bridge_mode", "none"), + train_layer_indices=getattr(args, "train_layer_indices", None), 
trainer_rank=getattr(args, "trainer_rank", 0), world_size=getattr(args, "world_size", 1), init_method=getattr(args, "init_method", "env://"), diff --git a/example_trainer/config.py b/example_trainer/config.py index 7acc3494..cce035b7 100644 --- a/example_trainer/config.py +++ b/example_trainer/config.py @@ -32,21 +32,41 @@ class TrainingConfig(BaseModel): gradient_accumulation_steps: int = Field( 32, description="Number of gradient accumulation steps" ) + warmup_steps: int = Field( + 0, + description=( + "Number of initial optimizer steps for linear LR warmup. " + "0 disables warmup." + ), + ) optimizer: Literal["adamw", "adamw_8bit", "adafactor"] = Field( "adamw_8bit", - description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), " - "'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), " - "'adafactor' (no momentum, ~8GB GPU)", + description="Optimizer to use: 'adamw' (full precision), " + "'adamw_8bit' (8-bit states, requires bitsandbytes), " + "'adafactor' (Adafactor optimizer)", + ) + adafactor_scale_parameter: bool = Field( + False, + description=( + "Whether to enable Adafactor scale_parameter behavior when using " + "optimizer='adafactor'." + ), + ) + adafactor_relative_step: bool = Field( + False, + description=( + "Whether to enable Adafactor relative_step behavior when using " + "optimizer='adafactor'." + ), ) # === GRPO/PPO Hyperparameters === kl_coef: float = Field( - 0.1, + 0.0, description=( - "KL divergence penalty coefficient (beta). " - "Controls how much the policy can deviate from the reference (inference-time) policy. " - "Higher values = more conservative updates, prevents reward hacking. " - "Set to 0 to disable KL penalty (not recommended)." + "Coefficient for sampled-token KL-like regularization against rollout/reference " + "logprobs. Higher values make updates more conservative. " + "Set to 0 to disable this term." 
def _apply_train_layer_filter(
    model: torch.nn.Module, layer_indices: Optional[list[int]]
) -> None:
    """
    Freeze all parameters except those in the selected transformer blocks.

    Applies to full-model modes (shared_vllm / legacy), not LoRA.

    A parameter belongs to block ``N`` when its name contains a ``layers.N.``
    or ``h.N.`` path component. The ``h`` spelling is accepted because the
    layer count may come from ``config.n_layer``, and such configs go with
    ``transformer.h.N.`` parameter names.

    Args:
        model: Model whose ``config`` exposes ``num_hidden_layers`` or
            ``n_layer``.
        layer_indices: 0-based block indices to keep trainable. ``None``
            leaves the model untouched (all layers stay as they are).

    Raises:
        RuntimeError: If the layer count cannot be read from the config, or
            if the selection matches no parameter at all.
        ValueError: If any index is negative or >= the layer count.

    The model is only modified after every check has passed, so a failed call
    never leaves it fully frozen.
    """
    if layer_indices is None:
        return

    num_hidden_layers = getattr(model.config, "num_hidden_layers", None)
    if num_hidden_layers is None:
        num_hidden_layers = getattr(model.config, "n_layer", None)
    if num_hidden_layers is None:
        raise RuntimeError(
            "Model config does not expose num_hidden_layers or n_layer; "
            "cannot validate --train-layer-indices for this architecture."
        )

    # Reject negatives as well: they can never match a parameter name and
    # would otherwise be silently dropped from the selection.
    invalid = [idx for idx in layer_indices if not 0 <= idx < num_hidden_layers]
    if invalid:
        raise ValueError(
            f"Invalid --train-layer-indices {invalid} for model with "
            f"{num_hidden_layers} layers (valid range: 0-{num_hidden_layers - 1})"
        )

    allowed = set(layer_indices)
    layer_pattern = re.compile(r"(?:^|\.)(?:layers|h)\.(\d+)\.")

    # First pass: decide per parameter without touching the model.
    decisions = []
    trainable_params = 0
    total_params = 0
    for name, param in model.named_parameters():
        match = layer_pattern.search(name)
        should_train = bool(match and int(match.group(1)) in allowed)
        decisions.append((param, should_train))
        total_params += param.numel()
        if should_train:
            trainable_params += param.numel()

    if trainable_params == 0:
        raise RuntimeError(
            "--train-layer-indices did not match any trainable parameters. "
            "Check architecture naming and selected indices."
        )

    # Second pass: apply the freeze/unfreeze decisions.
    for param, should_train in decisions:
        param.requires_grad_(should_train)

    pct = 100.0 * trainable_params / max(total_params, 1)
    print(
        f"[Setup] Training only transformer layers {sorted(allowed)} "
        f"({trainable_params}/{total_params} params, {pct:.2f}%)"
    )
# Print the trainer flag that restricts full-model training to the layers in
# SHARED_LAYER_INDICES. Prints nothing when the variable is empty.
# Call sites expand the output unquoted so it splits into flag + value.
add_shared_layer_flag() {
  [[ -z "$SHARED_LAYER_INDICES" ]] && return 0
  echo "--train-layer-indices" "$SHARED_LAYER_INDICES"
}
def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
    """
    Build the optimizer selected by ``config.optimizer`` for a model.

    Only parameters with ``requires_grad=True`` are handed to the optimizer,
    so parameters frozen beforehand are left out.

    Options:
    - 'adamw': Standard AdamW
    - 'adamw_8bit': 8-bit AdamW from bitsandbytes (requires bitsandbytes)
    - 'adafactor': Adafactor optimizer (requires transformers)

    Raises:
        RuntimeError: If the model has no trainable parameters.
    """
    params_to_optimize = list(
        filter(lambda parameter: parameter.requires_grad, model.parameters())
    )
    if len(params_to_optimize) == 0:
        raise RuntimeError("No trainable parameters found for optimizer creation.")
    return create_optimizer_for_params(params_to_optimize, config)
(scale_parameter=%s, relative_step=%s)", + scale_parameter, + relative_step, ) - print("[Setup] Using Adafactor (no momentum, saves ~24GB)") return optimizer except ImportError: - print("[Setup] WARNING: transformers Adafactor not available, using AdamW") + logger.warning("[Setup] transformers Adafactor unavailable, using AdamW") # Default: standard AdamW optimizer = AdamW(params, lr=config.lr) - print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)") + logger.info("[Setup] Using standard AdamW optimizer") return optimizer @@ -176,6 +193,7 @@ def train_legacy(config: TrainingConfig): advantage_batches, temperature_batches, config, + step_idx=step, inference_logprob_batches=inference_logprob_batches, ) step_time = time.time() - step_start @@ -322,6 +340,7 @@ def train_shared_vllm(config: TrainingConfig): advantage_batches, temperature_batches, config, + step_idx=step, inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation ) step_time = time.time() - step_start @@ -481,6 +500,7 @@ def train_lora(config: TrainingConfig): advantage_batches, temperature_batches, config, + step_idx=step, inference_logprob_batches=inference_logprob_batches, ) step_time = time.time() - step_start @@ -702,6 +722,7 @@ def train_lora_restart(config: TrainingConfig): advantage_batches, temperature_batches, config, + step_idx=step, inference_logprob_batches=inference_logprob_batches, ) step_time = time.time() - step_start diff --git a/example_trainer/training.py b/example_trainer/training.py index f44cfe84..c26fd428 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -69,7 +69,7 @@ def compute_grpo_loss( temperatures: torch.Tensor, gradient_accumulation_steps: int, inference_logprobs: Optional[torch.Tensor] = None, - kl_coef: float = 0.1, + kl_coef: float = 0.0, clip_eps: float = 0.2, use_reference_logprobs: bool = True, ) -> Tuple[torch.Tensor, dict]: @@ -79,12 +79,12 @@ def compute_grpo_loss( This implements proper 
GRPO/PPO with: - Importance sampling ratio: policy(a|s) / policy_old(a|s) - PPO-style clipping to prevent large updates - - KL penalty to prevent reward hacking/policy collapse + - Optional KL-like regularization term on sampled actions The loss encourages the model to: - Increase probability for tokens with positive advantages - Decrease probability for tokens with negative advantages - - Stay close to the reference policy (inference-time policy) + - Stay close to the rollout/reference policy on sampled tokens Args: model: The model to compute loss for @@ -94,7 +94,7 @@ def compute_grpo_loss( temperatures: Temperature values [batch, 1, 1] gradient_accumulation_steps: Number of accumulation steps (for scaling) inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len] - kl_coef: KL penalty coefficient (beta). Higher = more conservative updates + kl_coef: Coefficient for sampled-token KL-like regularization clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps] use_reference_logprobs: If True, use inference_logprobs as reference policy @@ -192,12 +192,13 @@ def compute_grpo_loss( # Average over tokens, then over batch policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean() - # KL penalty: encourage staying close to reference policy - # Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4): - # This estimator is guaranteed to be non-negative (unlike squared log-ratio). + # KL-like sampled-token regularizer: encourages staying close to rollout policy. + # This uses Schulman's non-negative estimator on sampled actions: + # exp(-log_ratio) + log_ratio - 1 + # where log_ratio = log pi(a|s) - log pi_ref(a|s). + # NOTE: this is not full-distribution KL unless evaluated over the full action space. if kl_coef > 0: - # Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1 - # = exp(-log_ratio) + log_ratio - 1 + # Schulman's sampled-token estimator. 
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0 kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean() total_loss = ( @@ -290,6 +291,7 @@ def run_training_step( advantage_batches: List[torch.Tensor], temperature_batches: List[torch.Tensor], config: TrainingConfig, + step_idx: int, inference_logprob_batches: Optional[List[torch.Tensor]] = None, ) -> dict: """ @@ -308,7 +310,8 @@ def run_training_step( label_batches: List of label tensors advantage_batches: List of advantage tensors temperature_batches: List of temperature tensors - config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs) + config: Training configuration (includes kl_coef, clip_eps, warmup_steps) + step_idx: Current global training step (0-based) inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels Returns: @@ -331,10 +334,20 @@ def run_training_step( all_inference_logprobs: List[torch.Tensor] = [] # Get GRPO hyperparameters from config - kl_coef = getattr(config, "kl_coef", 0.1) + kl_coef = getattr(config, "kl_coef", 0.0) clip_eps = getattr(config, "clip_eps", 0.2) use_reference_logprobs = getattr(config, "use_reference_logprobs", True) + # Apply linear warmup to optimizer LR for early-step stability. 
+ warmup_steps = max(0, int(getattr(config, "warmup_steps", 0))) + if warmup_steps > 0 and step_idx < warmup_steps: + warmup_scale = float(step_idx + 1) / float(max(1, warmup_steps)) + current_lr = float(config.lr) * warmup_scale + else: + current_lr = float(config.lr) + for param_group in optimizer.param_groups: + param_group["lr"] = current_lr + # Accumulate gradients over micro-batches num_batches = len(token_batches) if token_batches else 1 @@ -410,6 +423,7 @@ def run_training_step( result = { "loss": total_loss, + "lr": current_lr, "grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm, "pos_logp": total_pos_logp, "neg_logp": total_neg_logp, @@ -515,6 +529,7 @@ def log_metrics( log_dict = { "train/loss": metrics["loss"], "train/grad_norm": metrics["grad_norm"], + "train/lr": metrics.get("lr", 0.0), "train/pos_logp": metrics.get("pos_logp", 0), "train/neg_logp": metrics.get("neg_logp", 0), # GRPO-specific metrics