diff --git a/example_trainer/cli.py b/example_trainer/cli.py index b9f13fc0..cb500005 100644 --- a/example_trainer/cli.py +++ b/example_trainer/cli.py @@ -70,6 +70,38 @@ def parse_args() -> argparse.Namespace: "'adamw_cpu' (CPU offload, ~0GB GPU, slower), " "'adafactor' (no momentum, ~8GB GPU)", ) + + # === GRPO/PPO Hyperparameters === + parser.add_argument( + "--kl-coef", + type=float, + default=0.1, + help=( + "KL divergence penalty coefficient (beta). " + "Controls policy deviation from reference. " + "Higher = more conservative, prevents reward hacking. " + "0 = disabled (not recommended)." + ), + ) + parser.add_argument( + "--clip-eps", + type=float, + default=0.2, + help=( + "PPO-style clipping epsilon. " + "Clips importance ratio to [1-eps, 1+eps]. " + "Prevents destabilizing large policy updates." + ), + ) + parser.add_argument( + "--no-reference-logprobs", + action="store_true", + help=( + "Disable use of inference logprobs as reference policy. " + "Falls back to REINFORCE-style updates (not recommended)." + ), + ) + parser.add_argument( "--device", type=str, @@ -265,6 +297,11 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: device=args.device, save_path=args.save_path, checkpoint_interval=getattr(args, "checkpoint_interval", 3), + # GRPO/PPO hyperparameters + kl_coef=getattr(args, "kl_coef", 0.1), + clip_eps=getattr(args, "clip_eps", 0.2), + use_reference_logprobs=not getattr(args, "no_reference_logprobs", False), + # vLLM settings vllm_restart_interval=args.vllm_restart_interval, vllm_port=args.vllm_port, vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization, diff --git a/example_trainer/config.py b/example_trainer/config.py index 858472e2..5b4da672 100644 --- a/example_trainer/config.py +++ b/example_trainer/config.py @@ -40,6 +40,33 @@ class TrainingConfig(BaseModel): "'adafactor' (no momentum, ~8GB GPU)" ) + # === GRPO/PPO Hyperparameters === + kl_coef: float = Field( + 0.1, + description=( + "KL divergence penalty coefficient (beta). " + "Controls how much the policy can deviate from the reference (inference-time) policy. " + "Higher values = more conservative updates, prevents reward hacking. " + "Set to 0 to disable KL penalty (not recommended)." + ), + ) + clip_eps: float = Field( + 0.2, + description=( + "PPO-style clipping epsilon. " + "Clips the importance sampling ratio to [1-eps, 1+eps]. " + "Prevents large policy updates that could destabilize training." + ), + ) + use_reference_logprobs: bool = Field( + True, + description=( + "Whether to use inference logprobs as the reference policy (π_old). " + "When True, implements proper GRPO with importance sampling. " + "When False, falls back to REINFORCE-style updates (not recommended)." + ), + ) + # === Device & Storage === device: str = Field( "cuda" if torch.cuda.is_available() else "cpu", diff --git a/example_trainer/data.py b/example_trainer/data.py index f290adab..be7667c2 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -4,7 +4,9 @@ Data processing utilities for GRPO trainer. Handles data retrieval from Atropos API, padding, batching, and advantage normalization. -Also extracts inference logprobs for alignment validation with training logprobs. +Also extracts inference logprobs for proper GRPO loss computation: +- Inference logprobs serve as π_old (reference policy) for importance sampling +- They are batched and padded to align token-by-token with training labels """ import json @@ -23,11 +25,11 @@ def pad_data_to_good_offset( batch_size: int, extract_inference_logprobs: bool = True, ) -> Tuple[ - List[torch.Tensor], - List[torch.Tensor], - List[torch.Tensor], - List[torch.Tensor], - Optional[List[np.ndarray]], + List[torch.Tensor], # token_batches + List[torch.Tensor], # label_batches + List[torch.Tensor], # advantage_batches + List[torch.Tensor], # temperature_batches + Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels) ]: """ Pad and batch data from the Atropos API. @@ -36,7 +38,7 @@ def pad_data_to_good_offset( - Pads token sequences to nearest multiple of 64 - Normalizes advantage scores - Extracts temperature values - - Optionally extracts inference logprobs for alignment validation + - Extracts and pads inference logprobs for proper GRPO loss computation Args: data: Raw batch data from Atropos API @@ -44,8 +46,12 @@ def pad_data_to_good_offset( extract_inference_logprobs: Whether to extract inference logprobs Returns: - Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs) - inference_logprobs is None if extract_inference_logprobs=False or no logprobs in data + Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches) + inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data + + Note: + inference_logprob_batches are padded with 0.0 at positions where labels == -100. + This allows token-by-token alignment during GRPO loss computation. """ max_token_len = max( [max([len(x) for x in item["tokens"]]) for item in data["batch"]] @@ -66,7 +72,8 @@ def pad_data_to_good_offset( advantages = [] lengths = [] temperatures = [] - inference_logprobs_list: List[np.ndarray] = [] + inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape + has_any_logprobs = False for item in data["batch"]: # Normalize advantage scores @@ -84,15 +91,16 @@ def pad_data_to_good_offset( # Process each sample in the item for i in range(len(item["tokens"])): + seq_len = len(item["tokens"][i]) lengths.append( - math.ceil((len(item["tokens"][i]) - 1) / good_multiple) * good_multiple + math.ceil((seq_len - 1) / good_multiple) * good_multiple ) - # Create labels with padding + # Create labels with padding (-100 for masked positions) label_item = np.concatenate([ np.array(item["masks"][i]), np.full( - max(0, token_setup_len - len(item["tokens"][i])), + max(0, token_setup_len - seq_len), -100, dtype=np.int32, ), @@ -102,7 +110,7 @@ def pad_data_to_good_offset( item["tokens"][i] = np.concatenate([ np.array(item["tokens"][i]), np.zeros( - max(0, token_setup_len - len(item["tokens"][i])), + max(0, token_setup_len - seq_len), dtype=np.int32, ), ]) @@ -111,13 +119,36 @@ def pad_data_to_good_offset( labels.append(label_item[1:]) # Shift by 1 for causal advantages.append(item["scores"][i]) - # Extract inference logprobs for alignment validation - # These come from vLLM during rollout generation + # Extract and pad inference logprobs to match labels shape + # Inference logprobs are ONLY for generated tokens (where labels != -100) + # We need to create a padded array that aligns position-by-position if extract_inference_logprobs and "inference_logprobs" in item: if i < len(item["inference_logprobs"]): - inference_logprobs_list.append( - np.array(item["inference_logprobs"][i], dtype=np.float32) - ) + raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32) + has_any_logprobs = True + + # Create padded logprobs array matching label_item shape + # Fill with 0.0 (will be masked out during loss computation) + padded_logprobs = np.zeros(token_setup_len, dtype=np.float32) + + # The inference logprobs correspond to generated tokens + # Find positions where labels != -100 (generated positions) + mask_arr = np.array(item["masks"][i]) + generated_positions = np.where(mask_arr != -100)[0] + + # Fill in inference logprobs at generated positions + n_to_fill = min(len(raw_logprobs), len(generated_positions)) + if n_to_fill > 0: + padded_logprobs[generated_positions[:n_to_fill]] = raw_logprobs[:n_to_fill] + + # Shift by 1 to match causal label shift + inference_logprobs_padded.append(padded_logprobs[1:]) + else: + # No logprobs for this sample, use zeros + inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32)) + elif extract_inference_logprobs: + # No inference_logprobs in item, use zeros + inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32)) # Extract temperature (priority: override > generation_params > group_overrides > 1.0) t = 1.0 @@ -139,6 +170,7 @@ def pad_data_to_good_offset( label_batches = [] advantage_batches = [] temperature_batches = [] + inference_logprob_batches = [] for i in range(len(input_ids) // batch_size): start = i * batch_size @@ -158,11 +190,17 @@ def pad_data_to_good_offset( np.array(temperatures[start:end], dtype=np.float32) ).view(-1, 1, 1) ) + + # Batch inference logprobs (same shape as labels) + if extract_inference_logprobs and inference_logprobs_padded: + inference_logprob_batches.append( + torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0)) + ) - # Return inference logprobs if available - inference_logprobs = inference_logprobs_list if inference_logprobs_list else None + # Return inference logprob batches if we have any real logprobs + final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None - return token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs + return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches def get_data( @@ -171,8 +209,14 @@ def get_data( atropos_url: str = "http://localhost:8000", extract_inference_logprobs: bool = True, ) -> Tuple[ - List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]], - Optional[List[np.ndarray]], + List[Tuple[ + List[torch.Tensor], # token_batches + List[torch.Tensor], # label_batches + List[torch.Tensor], # advantage_batches + List[torch.Tensor], # temperature_batches + Optional[List[torch.Tensor]], # inference_logprob_batches + ]], + None, # Legacy return (no longer used) ]: """ Fetch and process training data from the Atropos API. @@ -184,15 +228,15 @@ def get_data( batch_size: Size of each training batch seq_len: Maximum sequence length (for reference, not used directly) atropos_url: URL of the Atropos API server - extract_inference_logprobs: Whether to extract inference logprobs for alignment + extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss Returns: - Tuple of (batches, all_inference_logprobs) - - batches: List of processed batch tuples - - all_inference_logprobs: List of inference logprob arrays for alignment validation + Tuple of (batches, None) + - batches: List of processed batch tuples, each containing: + (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches) + - inference_logprob_batches are aligned with labels for proper GRPO loss computation """ batches = [] - all_inference_logprobs: List[np.ndarray] = [] while True: data = get_batch(url=atropos_url) @@ -202,18 +246,16 @@ def get_data( with open("temp.json", "w", encoding="utf-8") as f: json.dump(data, f) - # Process and accumulate batches - token_batches, label_batches, adv_batches, temp_batches, inf_logprobs = \ + # Process and accumulate batches (now includes batched inference logprobs) + token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \ pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) - batches.append((token_batches, label_batches, adv_batches, temp_batches)) - - if inf_logprobs: - all_inference_logprobs.extend(inf_logprobs) + # Include inference logprob batches in the tuple + batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches)) elif len(batches) > 0: # Return accumulated batches when no more data - return batches, all_inference_logprobs if all_inference_logprobs else None + return batches, None else: # Wait for data time.sleep(1) diff --git a/example_trainer/test_multi_model.py b/example_trainer/test_multi_model.py deleted file mode 100644 index 088deb21..00000000 --- a/example_trainer/test_multi_model.py +++ /dev/null @@ -1,647 +0,0 @@ -#!/usr/bin/env python3 -""" -Multi-model test suite for shared_vllm trainer. - -Tests the trainer against diverse models to verify robustness. -Supports both parallel (different GPUs) and sequential execution. - -With --auto-env, each model gets its own isolated stack: - - run-api (port 8002 + offset) - - gsm8k environment (with model-specific tokenizer) - - vLLM server (port 9001 + offset) - - trainer - -Usage: - # RECOMMENDED: Fully automated parallel test with W&B logging - python -m example_trainer.test_multi_model \ - --models qwen3-4b hermes-8b nemotron-14b devstral-24b \ - --parallel \ - --gpus 0 1 2 3 \ - --auto-env \ - --use-wandb \ - --wandb-project multi-model-test - - # Sequential test on one GPU - python -m example_trainer.test_multi_model \ - --models qwen3-4b hermes-8b \ - --sequential \ - --gpu 0 \ - --auto-env \ - --use-wandb - - # Manual mode (you must start run-api and gsm8k_server yourself) - python -m example_trainer.test_multi_model \ - --models qwen3-4b \ - --sequential \ - --gpu 0 \ - --atropos-url http://localhost:8002 - -Port allocation with --auto-env: - Model 0: run-api:8002, vLLM:9001, GPU from --gpus[0] - Model 1: run-api:8003, vLLM:9002, GPU from --gpus[1] - Model 2: run-api:8004, vLLM:9003, GPU from --gpus[2] - ... -""" - -import argparse -import json -import os -import signal -import subprocess -import sys -import time -from dataclasses import dataclass -from datetime import datetime -from pathlib import Path -from typing import Dict, List, Optional -import threading - - -@dataclass -class ModelConfig: - """Configuration for a test model.""" - name: str - model_id: str - gpu_memory_utilization: float = 0.5 - max_model_len: int = 4096 - dtype: str = "bfloat16" - training_steps: int = 10 - notes: str = "" - - -# Define test models -# Memory estimates for B200 (183GB): -# - Model weights (bf16): 2 bytes/param -# - Gradients: ~same as weights -# - 8-bit optimizer: ~1 byte/param -# - KV cache: depends on max_model_len -TEST_MODELS: Dict[str, ModelConfig] = { - "qwen3-4b": ModelConfig( - name="qwen3-4b", - model_id="Qwen/Qwen3-4B-Instruct-2507", - gpu_memory_utilization=0.4, # ~73GB for vLLM - max_model_len=8192, # Plenty of room on B200 - notes="Small 4B model, good baseline test (~8GB weights)", - ), - "hermes-8b": ModelConfig( - name="hermes-8b", - model_id="NousResearch/Hermes-3-Llama-3.1-8B", - gpu_memory_utilization=0.45, # ~82GB for vLLM - max_model_len=8192, # 8K context fits well - notes="Llama 8B architecture (~16GB weights)", - ), - "nemotron-14b": ModelConfig( - name="nemotron-14b", - model_id="nvidia/Nemotron-Cascade-14B-Thinking", - gpu_memory_utilization=0.5, # ~91GB for vLLM - max_model_len=32768, # 32K context for thinking - notes="14B thinking model (~28GB weights), needs room for long CoT", - ), - "devstral-24b": ModelConfig( - name="devstral-24b", - model_id="mistralai/Devstral-Small-2-24B-Instruct-2512", - gpu_memory_utilization=0.55, # ~100GB for vLLM - max_model_len=16384, # 16K context (conservative for 24B) - notes="Large 24B Mistral (~48GB weights), largest model", - ), -} - - -def get_test_dir(base_dir: str, model_name: str, timestamp: str) -> Path: - """Get unique test directory for a model run.""" - return Path(base_dir) / f"{model_name}_{timestamp}" - - -def start_run_api( - port: int, - log_path: Path, -) -> subprocess.Popen: - """Start a run-api instance on a specific port.""" - cmd = [sys.executable, "-m", "atroposlib.cli.run_api", "--port", str(port)] - - log_file = open(log_path, "w") - process = subprocess.Popen( - cmd, - stdout=log_file, - stderr=subprocess.STDOUT, - # Don't buffer output - bufsize=1, - ) - return process - - -def wait_for_run_api(port: int, timeout: int = 60) -> bool: - """Wait for run-api to be ready.""" - import requests - start = time.time() - while time.time() - start < timeout: - try: - # run-api uses /status or / endpoint, not /health - resp = requests.get(f"http://localhost:{port}/status", timeout=5) - if resp.status_code == 200: - return True - except: - pass - try: - # Fallback to root endpoint - resp = requests.get(f"http://localhost:{port}/", timeout=5) - if resp.status_code == 200: - return True - except: - pass - time.sleep(2) - return False - - -def start_gsm8k_env( - model_id: str, - vllm_port: int, - run_api_port: int, - log_path: Path, - atropos_root: Path, -) -> subprocess.Popen: - """Start a gsm8k environment process for a specific model.""" - gsm8k_script = atropos_root / "environments" / "gsm8k_server.py" - cmd = [ - sys.executable, "-u", str(gsm8k_script), "serve", - "--env.rollout_server_url", f"http://localhost:{run_api_port}", - "--env.tokenizer_name", model_id, - "--env.use_wandb", "false", - "--env.total_steps", "10000", - "--env.batch_size", "64", - "--env.group_size", "8", - "--openai.model_name", model_id, - "--openai.base_url", f"http://localhost:{vllm_port}/v1", - "--openai.api_key", "x", - "--openai.server_type", "openai", - ] - - log_file = open(log_path, "w") - process = subprocess.Popen( - cmd, - stdout=log_file, - stderr=subprocess.STDOUT, - cwd=str(atropos_root), # Run from atropos root - ) - return process - - -def run_model_test( - model_config: ModelConfig, - gpu_id: int, - atropos_url: str, - atropos_port: int, - base_dir: str, - timestamp: str, - training_steps: int, - vllm_port_offset: int = 0, - auto_env: bool = False, - use_wandb: bool = False, - wandb_project: str = "multi-model-test", -) -> Dict: - """ - Run a complete training test for a single model. - - Returns dict with test results. - """ - model_name = model_config.name - test_dir = get_test_dir(base_dir, model_name, timestamp).resolve() # Make absolute - test_dir.mkdir(parents=True, exist_ok=True) - - # Unique paths for this model (all absolute) - vllm_port = 9001 + vllm_port_offset - bridge_config_path = test_dir / "vllm_bridge_config.json" - checkpoint_dir = test_dir / "checkpoints" - log_dir = test_dir / "logs" - log_dir.mkdir(exist_ok=True) - - vllm_log = log_dir / "vllm.log" - trainer_log = log_dir / "trainer.log" - - # Each model gets unique ports - run_api_port = 8002 + vllm_port_offset - - result = { - "model": model_config.model_id, - "model_name": model_name, - "gpu": gpu_id, - "vllm_port": vllm_port, - "run_api_port": run_api_port, - "test_dir": str(test_dir), - "status": "pending", - "error": None, - "start_time": None, - "end_time": None, - "duration_seconds": None, - "real_time_alignment": None, - "final_gpu_memory": None, - } - - print(f"\n{'='*60}") - print(f"[{model_name}] Starting test on GPU {gpu_id}") - print(f"[{model_name}] Model: {model_config.model_id}") - print(f"[{model_name}] vLLM port: {vllm_port}") - print(f"[{model_name}] Test dir: {test_dir}") - print(f"{'='*60}\n") - - result["start_time"] = datetime.now().isoformat() - start_time = time.time() - - env_process = None - run_api_process = None - trainer_process = None - - # Get atropos root directory (used for vLLM and gsm8k scripts) - script_dir = Path(__file__).parent - atropos_root = script_dir.parent.resolve() - - try: - # === Start run-api (if auto_env) === - if auto_env: - run_api_log = log_dir / "run_api.log" - print(f"[{model_name}] Starting run-api on port {run_api_port}...") - run_api_process = start_run_api(run_api_port, run_api_log) - - if not wait_for_run_api(run_api_port, timeout=60): - # Check if process died - if run_api_process.poll() is not None: - print(f"[{model_name}] run-api process exited with code {run_api_process.returncode}") - # Print log contents for debugging - if run_api_log.exists(): - print(f"[{model_name}] run-api log contents:") - print(run_api_log.read_text()[-2000:]) # Last 2000 chars - raise RuntimeError(f"run-api failed to start on port {run_api_port}") - print(f"[{model_name}] ✓ run-api ready on port {run_api_port}") - - # Update atropos_url to use this model's run-api - atropos_url = f"http://localhost:{run_api_port}" - - # === Start gsm8k Environment (if auto_env) === - if auto_env: - env_log = log_dir / "env.log" - print(f"[{model_name}] Starting gsm8k environment (tokenizer: {model_config.model_id})...") - env_process = start_gsm8k_env( - model_config.model_id, vllm_port, run_api_port, env_log, atropos_root - ) - time.sleep(10) # Give it time to initialize and connect - print(f"[{model_name}] ✓ gsm8k environment started") - - # === Start Unified vLLM + Trainer (run.py) === - # Using run.py ensures vLLM is a CHILD of the trainer process, - # which is required for CUDA IPC with ptrace_scope=1 - run_script = script_dir / "run.py" - - # Don't use CUDA_VISIBLE_DEVICES - use --device instead - # run.py sets CUDA_VISIBLE_DEVICES internally based on --device - run_env = os.environ.copy() - run_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - - run_cmd = [ - sys.executable, "-u", str(run_script), - "--model", model_config.model_id, - "--device", f"cuda:{gpu_id}", # This controls GPU selection - "--vllm-port", str(vllm_port), - "--gpu-memory-utilization", str(model_config.gpu_memory_utilization), - "--max-model-len", str(model_config.max_model_len), - "--dtype", model_config.dtype, - "--atropos-url", atropos_url, - "--training-steps", str(training_steps), - "--optimizer", "adamw_8bit", - "--save-path", str(checkpoint_dir), - "--checkpoint-interval", "5", - "--log-dir", str(log_dir), - ] - - # Add wandb flags if enabled - if use_wandb: - run_cmd.extend(["--use-wandb", "--wandb-project", wandb_project]) - - print(f"[{model_name}] Starting unified trainer (vLLM + GRPO) for {training_steps} steps...") - with open(trainer_log, "w") as tlog: - trainer_process = subprocess.Popen( - run_cmd, - env=run_env, - stdout=tlog, - stderr=subprocess.STDOUT, - cwd=str(atropos_root), # Run from atropos root - ) - trainer_process.wait() - - if trainer_process.returncode != 0: - raise RuntimeError(f"Unified trainer exited with code {trainer_process.returncode}") - - result["status"] = "success" - print(f"[{model_name}] ✓ Training completed successfully!") - - # Parse trainer log for metrics - try: - with open(trainer_log, "r") as f: - log_content = f.read() - - # Extract real-time alignment - if "Mean diff:" in log_content: - import re - match = re.search(r"Mean diff: ([\d.]+)", log_content) - if match: - result["real_time_alignment"] = float(match.group(1)) - - # Extract final GPU memory - if "GPU mem:" in log_content: - matches = re.findall(r"GPU mem: ([\d.]+)GB", log_content) - if matches: - result["final_gpu_memory"] = float(matches[-1]) - except Exception as e: - print(f"[{model_name}] Warning: Could not parse log: {e}") - - except Exception as e: - result["status"] = "failed" - result["error"] = str(e) - print(f"[{model_name}] ✗ Test failed: {e}") - import traceback - traceback.print_exc() - - finally: - # Note: vLLM is managed by run.py and cleaned up automatically - - # Cleanup gsm8k environment - if env_process and env_process.poll() is None: - print(f"[{model_name}] Terminating gsm8k environment...") - env_process.terminate() - try: - env_process.wait(timeout=10) - except subprocess.TimeoutExpired: - env_process.kill() - - # Cleanup run-api - if run_api_process and run_api_process.poll() is None: - print(f"[{model_name}] Terminating run-api...") - run_api_process.terminate() - try: - run_api_process.wait(timeout=10) - except subprocess.TimeoutExpired: - run_api_process.kill() - - result["end_time"] = datetime.now().isoformat() - result["duration_seconds"] = time.time() - start_time - - return result - - -def run_parallel_tests( - models: List[ModelConfig], - gpu_ids: List[int], - atropos_url: str, - atropos_port: int, - base_dir: str, - training_steps: int, - auto_env: bool = False, - use_wandb: bool = False, - wandb_project: str = "multi-model-test", -) -> List[Dict]: - """Run tests for multiple models in parallel.""" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - results = [] - threads = [] - result_lock = threading.Lock() - - def run_and_store(model, gpu, port_offset): - result = run_model_test( - model, gpu, atropos_url, atropos_port, base_dir, timestamp, - training_steps, port_offset, auto_env, use_wandb, wandb_project - ) - with result_lock: - results.append(result) - - # Start threads - for i, (model, gpu) in enumerate(zip(models, gpu_ids)): - t = threading.Thread(target=run_and_store, args=(model, gpu, i)) - t.start() - threads.append(t) - time.sleep(5) # Stagger starts slightly - - # Wait for all to complete - for t in threads: - t.join() - - return results - - -def run_sequential_tests( - models: List[ModelConfig], - gpu_id: int, - atropos_url: str, - atropos_port: int, - base_dir: str, - training_steps: int, - auto_env: bool = False, - use_wandb: bool = False, - wandb_project: str = "multi-model-test", -) -> List[Dict]: - """Run tests for multiple models sequentially on one GPU.""" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - results = [] - - for i, model in enumerate(models): - result = run_model_test( - model, gpu_id, atropos_url, atropos_port, base_dir, timestamp, - training_steps, port_offset=0, auto_env=auto_env, - use_wandb=use_wandb, wandb_project=wandb_project - ) - results.append(result) - - # Give GPU time to fully release memory - time.sleep(10) - - return results - - -def print_summary(results: List[Dict]): - """Print summary of test results.""" - print("\n" + "="*80) - print("TEST SUMMARY") - print("="*80) - - for r in results: - status_icon = "✓" if r["status"] == "success" else "✗" - duration = f"{r['duration_seconds']:.1f}s" if r['duration_seconds'] else "N/A" - alignment = f"{r['real_time_alignment']:.4f}" if r['real_time_alignment'] else "N/A" - memory = f"{r['final_gpu_memory']:.1f}GB" if r['final_gpu_memory'] else "N/A" - - print(f"\n{status_icon} {r['model_name']}") - print(f" Model: {r['model']}") - print(f" GPU: {r['gpu']}, vLLM port: {r['vllm_port']}, run-api port: {r.get('run_api_port', 'N/A')}") - print(f" Status: {r['status']}") - print(f" Duration: {duration}") - print(f" Real-time alignment: {alignment}") - print(f" GPU memory: {memory}") - if r["error"]: - print(f" Error: {r['error']}") - print(f" Logs: {r['test_dir']}/logs/") - - # Summary stats - successes = sum(1 for r in results if r["status"] == "success") - failures = len(results) - successes - - print(f"\n{'='*80}") - print(f"TOTAL: {successes} passed, {failures} failed") - print("="*80) - - -def main(): - parser = argparse.ArgumentParser( - description="Multi-model test suite for shared_vllm trainer", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Run all models in parallel (one per GPU) - python -m example_trainer.test_multi_model --parallel - - # Run specific models - python -m example_trainer.test_multi_model --models hermes-8b qwen3-4b --parallel - - # Run sequentially on GPU 0 - python -m example_trainer.test_multi_model --sequential --gpu 0 - -Available models: """ + ", ".join(TEST_MODELS.keys()) - ) - - parser.add_argument( - "--models", - nargs="+", - choices=list(TEST_MODELS.keys()), - default=["qwen3-4b", "hermes-8b"], - help="Models to test", - ) - parser.add_argument( - "--parallel", - action="store_true", - help="Run models in parallel on different GPUs", - ) - parser.add_argument( - "--sequential", - action="store_true", - help="Run models sequentially on one GPU", - ) - parser.add_argument( - "--gpus", - type=int, - nargs="+", - default=None, - help="GPU IDs to use (for parallel mode)", - ) - parser.add_argument( - "--gpu", - type=int, - default=0, - help="GPU ID (for sequential mode)", - ) - parser.add_argument( - "--atropos-url", - type=str, - default="http://localhost:8002", - help="Atropos API URL", - ) - parser.add_argument( - "--atropos-port", - type=int, - default=8002, - help="Atropos API port (for spawning multiple if needed)", - ) - parser.add_argument( - "--training-steps", - type=int, - default=10, - help="Number of training steps per model", - ) - parser.add_argument( - "--output-dir", - type=str, - default="./multi_model_tests", - help="Base directory for test outputs", - ) - parser.add_argument( - "--auto-env", - action="store_true", - help="Automatically start run-api and gsm8k environment for each model", - ) - parser.add_argument( - "--use-wandb", - action="store_true", - help="Enable Weights & Biases logging for training runs", - ) - parser.add_argument( - "--wandb-project", - type=str, - default="multi-model-test", - help="W&B project name for logging", - ) - - args = parser.parse_args() - - if not args.parallel and not args.sequential: - args.sequential = True # Default to sequential - - # Get model configs - models = [TEST_MODELS[name] for name in args.models] - - print(f"\n{'#'*60}") - print("# MULTI-MODEL SHARED_VLLM TRAINER TEST SUITE") - print(f"{'#'*60}") - print(f"\nModels to test: {[m.name for m in models]}") - print(f"Mode: {'Parallel' if args.parallel else 'Sequential'}") - print(f"Training steps per model: {args.training_steps}") - print(f"Output directory: {args.output_dir}") - print(f"Atropos URL: {args.atropos_url}") - - # Run tests - if args.auto_env: - print(f"Auto-env: Will start gsm8k environment per model") - - if args.parallel: - gpus = args.gpus or list(range(len(models))) - if len(gpus) < len(models): - print(f"\nWarning: Not enough GPUs ({len(gpus)}) for models ({len(models)})") - print("Some models will share GPUs") - gpus = gpus * (len(models) // len(gpus) + 1) - - print(f"Using GPUs: {gpus[:len(models)]}") - if args.use_wandb: - print(f"W&B logging enabled (project: {args.wandb_project})") - results = run_parallel_tests( - models, gpus[:len(models)], - args.atropos_url, args.atropos_port, - args.output_dir, args.training_steps, - auto_env=args.auto_env, - use_wandb=args.use_wandb, - wandb_project=args.wandb_project, - ) - else: - print(f"Using GPU: {args.gpu}") - if args.use_wandb: - print(f"W&B logging enabled (project: {args.wandb_project})") - results = run_sequential_tests( - models, args.gpu, - args.atropos_url, args.atropos_port, - args.output_dir, args.training_steps, - auto_env=args.auto_env, - use_wandb=args.use_wandb, - wandb_project=args.wandb_project, - ) - - # Print summary - print_summary(results) - - # Save results to JSON - results_file = Path(args.output_dir) / f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - results_file.parent.mkdir(parents=True, exist_ok=True) - with open(results_file, "w") as f: - json.dump(results, f, indent=2) - print(f"\nResults saved to: {results_file}") - - # Exit with error code if any failed - if any(r["status"] != "success" for r in results): - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/example_trainer/training.py b/example_trainer/training.py index 79c82a9c..5b689cc8 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -153,13 +153,22 @@ def compute_grpo_loss( temperatures: torch.Tensor, gradient_accumulation_steps: int, inference_logprobs: Optional[torch.Tensor] = None, + kl_coef: float = 0.1, + clip_eps: float = 0.2, + use_reference_logprobs: bool = True, ) -> Tuple[torch.Tensor, dict]: """ Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch. - The GRPO loss encourages the model to: + This implements proper GRPO/PPO with: + - Importance sampling ratio: π(a|s) / π_old(a|s) + - PPO-style clipping to prevent large updates + - KL penalty to prevent reward hacking/policy collapse + + The loss encourages the model to: - Increase probability for tokens with positive advantages - Decrease probability for tokens with negative advantages + - Stay close to the reference policy (inference-time policy) Args: model: The model to compute loss for @@ -168,7 +177,10 @@ def compute_grpo_loss( advantages: Advantage values [batch, 1] temperatures: Temperature values [batch, 1, 1] gradient_accumulation_steps: Number of accumulation steps (for scaling) - inference_logprobs: Optional logprobs from inference for alignment check + inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len] + kl_coef: KL penalty coefficient (beta). Higher = more conservative updates + clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps] + use_reference_logprobs: If True, use inference_logprobs as reference policy Returns: Tuple of (loss tensor, metrics dict) @@ -177,14 +189,14 @@ def compute_grpo_loss( outputs = model(tokens) logits = outputs.logits - # Temperature scaling + # Temperature scaling for training t = temperatures.to(logits.device, logits.dtype) t = torch.where(t <= 0, torch.ones_like(t), t) - logits = logits / t + scaled_logits = logits / t - # Log probabilities per token + # Log probabilities per token (current policy π) logp_per_token = -F.cross_entropy( - logits.view(-1, logits.size(-1)), + scaled_logits.view(-1, scaled_logits.size(-1)), labels.view(-1), reduction="none", ignore_index=-100, @@ -192,39 +204,103 @@ def compute_grpo_loss( # Masking based on labels != -100 mask = (labels != -100).float() + mask_sum = mask.sum(dim=-1).clamp_min(1e-8) - # Compute metrics (no grad needed) + # Expand advantages to match token shape [batch, 1] -> [batch, seq_len] + adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device) + + # === GRPO/PPO Loss Computation === + if use_reference_logprobs and inference_logprobs is not None: + # Move inference logprobs to correct device/dtype + ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype) + + # Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old) + log_ratio = logp_per_token - ref_logprobs + ratio = torch.exp(log_ratio) + + # PPO-style clipping + clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) + + # Surrogate objectives + surr1 = ratio * adv_expanded + surr2 = clipped_ratio * adv_expanded + + # Pessimistic bound: min for positive advantages, max for negative + # This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0 + # -max(ratio * A, clipped_ratio * A) when A < 0 + policy_loss_per_token = -torch.where( + adv_expanded >= 0, + torch.min(surr1, surr2), + torch.max(surr1, surr2), + ) + + # Average over tokens, then over batch + policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean() + + # KL penalty: encourage staying close to reference policy + # KL(π || π_ref) ≈ log(π/π_ref) = log_ratio (when π_ref is the reference) + # We use the approximation: KL ≈ (ratio - 1) - log(ratio) + # But simpler: just penalize squared log-ratio which is symmetric + if kl_coef > 0: + # Approximate KL using (log_ratio)^2 / 2 (Taylor expansion) + # Or just use log_ratio directly as a penalty + kl_per_token = log_ratio.pow(2) # Squared for symmetric penalty + kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean() + total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps + else: + kl_penalty = torch.tensor(0.0, device=logp_per_token.device) + total_loss = policy_loss / gradient_accumulation_steps + + # Compute metrics for logging + with torch.no_grad(): + # Fraction of tokens where ratio was clipped + clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float() + clipped_fraction = (clipped_fraction * mask).sum() / mask.sum() + + # Mean ratio and KL for monitoring + mean_ratio = (ratio * mask).sum() / mask.sum() + mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum() + + # For backward compatibility: collect training logprobs + raw_logp_per_token = -F.cross_entropy( + outputs.logits.view(-1, outputs.logits.size(-1)), + labels.view(-1), + reduction="none", + ignore_index=-100, + ).view(labels.shape) + training_logprobs_flat = raw_logp_per_token[mask.bool()].detach() + else: + # Fallback: REINFORCE-style (no reference policy) + # This is what the original code did - NOT recommended! + print(" [WARNING] No reference logprobs - using REINFORCE (may cause reward hacking!)") + + # Simple policy gradient: -log(π) * A + policy_loss = ((-logp_per_token * mask * adv_expanded).sum(dim=-1) / mask_sum).mean() + total_loss = policy_loss / gradient_accumulation_steps + kl_penalty = torch.tensor(0.0, device=logp_per_token.device) + + with torch.no_grad(): + clipped_fraction = torch.tensor(0.0) + mean_ratio = torch.tensor(1.0) + mean_kl = torch.tensor(0.0) + raw_logp_per_token = -F.cross_entropy( + outputs.logits.view(-1, outputs.logits.size(-1)), + labels.view(-1), + reduction="none", + ignore_index=-100, + ).view(labels.shape) + training_logprobs_flat = raw_logp_per_token[mask.bool()].detach() + + # === Compute Additional Metrics === with torch.no_grad(): pos = (advantages > 0).float() neg = (advantages <= 0).float() mask_float = mask.to(logp_per_token.dtype) - mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8) avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum pos_logp = (logp_per_token * pos).mean().item() neg_logp = (logp_per_token * neg).mean().item() - # For alignment check: compute logprobs WITHOUT temperature scaling - # This allows fair comparison with inference logprobs (which are at temp=1.0) - raw_logp_per_token = -F.cross_entropy( - outputs.logits.view(-1, outputs.logits.size(-1)), # Use original logits, not temp-scaled - labels.view(-1), - reduction="none", - ignore_index=-100, - ).view(labels.shape) - - # Collect raw training logprobs for masked positions (generated tokens only) - # Keep as PyTorch tensor (supports bfloat16 natively) - training_logprobs_flat = raw_logp_per_token[mask.bool()].detach() - - # GRPO loss: weighted log probabilities by advantages - grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach()) - grpo_loss = ( - ((-grpo_loss_term * mask).sum(-1) / mask.sum(-1)) - * advantages.to(logp_per_token.device) - ).mean() / gradient_accumulation_steps - - # Compute a more interpretable loss metric (advantage-weighted logprobs) - with torch.no_grad(): + # Interpretable metric: advantage-weighted average logprob interpretable_loss = (avg_logp * advantages.squeeze()).mean().item() metrics = { @@ -233,11 +309,16 @@ def compute_grpo_loss( "avg_logp": avg_logp, "pos_count": pos.sum().item(), "neg_count": neg.sum().item(), - "training_logprobs": training_logprobs_flat, # For alignment check - "interpretable_loss": interpretable_loss, # More meaningful metric + "training_logprobs": training_logprobs_flat, + "interpretable_loss": interpretable_loss, + # GRPO-specific metrics + "kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty, + "mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio, + "mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl, + "clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction, } - return grpo_loss, metrics + return total_loss, metrics def compute_logprob_alignment( @@ -309,17 +390,16 @@ def run_training_step( advantage_batches: List[torch.Tensor], temperature_batches: List[torch.Tensor], config: TrainingConfig, - inference_logprobs: Optional[List[np.ndarray]] = None, + inference_logprob_batches: Optional[List[torch.Tensor]] = None, ) -> dict: """ Run a single training step with gradient accumulation. Performs: - 1. Forward pass through all micro-batches + 1. Forward pass through all micro-batches with proper GRPO loss 2. Backward pass with gradient accumulation 3. Gradient clipping 4. Optimizer step - 5. (Optional) Logprob alignment check Args: model: The model to train @@ -328,8 +408,8 @@ def run_training_step( label_batches: List of label tensors advantage_batches: List of advantage tensors temperature_batches: List of temperature tensors - config: Training configuration - inference_logprobs: Optional logprobs from inference for alignment check + config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs) + inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels Returns: Dict of training metrics for this step @@ -341,16 +421,32 @@ def run_training_step( total_neg_logp = 0.0 total_pos = 0.0 total_neg = 0.0 + total_kl_penalty = 0.0 + total_mean_ratio = 0.0 + total_mean_kl = 0.0 + total_clipped_fraction = 0.0 grad_norm = 0.0 all_training_logprobs: List[torch.Tensor] = [] + # Get GRPO hyperparameters from config + kl_coef = getattr(config, 'kl_coef', 0.1) + clip_eps = getattr(config, 'clip_eps', 0.2) + use_reference_logprobs = getattr(config, 'use_reference_logprobs', True) + # Accumulate gradients over micro-batches - for tokens, labels, advantages, temperatures in zip( + num_batches = len(token_batches) if token_batches else 1 + + for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip( token_batches, label_batches, advantage_batches, temperature_batches - ): + )): tokens = tokens.to(config.device) labels = labels.to(config.device) advantages = advantages.to(config.device) + + # Get corresponding inference logprobs batch if available + inf_logprobs = None + if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches): + inf_logprobs = inference_logprob_batches[batch_idx] loss, metrics = compute_grpo_loss( model, @@ -359,6 +455,10 @@ def run_training_step( advantages, temperatures, config.gradient_accumulation_steps, + inference_logprobs=inf_logprobs, + kl_coef=kl_coef, + clip_eps=clip_eps, + use_reference_logprobs=use_reference_logprobs, ) loss.backward() @@ -368,7 +468,13 @@ def run_training_step( total_pos += metrics["pos_count"] total_neg += metrics["neg_count"] - # Collect training logprobs for alignment check + # Accumulate GRPO-specific metrics + total_kl_penalty += metrics.get("kl_penalty", 0.0) + total_mean_ratio += metrics.get("mean_ratio", 1.0) + total_mean_kl += metrics.get("mean_kl", 0.0) + total_clipped_fraction += metrics.get("clipped_fraction", 0.0) + + # Collect training logprobs for alignment monitoring if "training_logprobs" in metrics: all_training_logprobs.append(metrics["training_logprobs"]) @@ -380,8 +486,7 @@ def run_training_step( # Help prevent memory fragmentation torch.cuda.empty_cache() - # Normalize metrics by count - num_batches = len(token_batches) if token_batches else 1 + # Normalize metrics by batch count if total_pos > 0: total_pos_logp /= num_batches if total_neg > 0: @@ -394,18 +499,22 @@ def run_training_step( "neg_logp": total_neg_logp, "pos_count": total_pos, "neg_count": total_neg, + # GRPO-specific metrics (averaged over batches) + "kl_penalty": total_kl_penalty / num_batches, + "mean_ratio": total_mean_ratio / num_batches, + "mean_kl": total_mean_kl / num_batches, + "clipped_fraction": total_clipped_fraction / num_batches, } - # Compute logprob alignment stats - # NOTE: This comparison is approximate - inference and training logprobs - # come from different batching, so token-by-token alignment isn't possible. - # The real-time test at startup is the reliable alignment check. - if inference_logprobs is not None and all_training_logprobs: - alignment_stats = compute_logprob_alignment( - inference_logprobs, all_training_logprobs, debug=False - ) - _logprob_alignment_stats.update(alignment_stats) - result["logprob_alignment"] = alignment_stats + # Compute logprob alignment stats for monitoring + # NOTE: Now that we use proper GRPO, this is less critical + # but still useful for debugging weight sharing issues + if all_training_logprobs: + # Store training logprob stats + train_flat = torch.cat(all_training_logprobs) + if train_flat.numel() > 0: + _logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item() + _logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item() return result @@ -441,19 +550,27 @@ def log_metrics( if "gpu_memory_gb" in metrics: timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB" - # Show interpretable loss (advantage-weighted logprobs) if available - interp_loss = metrics.get("interpretable_loss") - if interp_loss is not None: - print(f" AdvWeightedLogP: {interp_loss:.4f}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}") - else: - loss_str = ( - f"{metrics['loss']:.6f}" - if abs(metrics["loss"]) < 0.01 - else f"{metrics['loss']:.4f}" - ) - print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}") + # Primary metrics line: Loss and grad norm + loss_str = ( + f"{metrics['loss']:.6f}" + if abs(metrics["loss"]) < 0.01 + else f"{metrics['loss']:.4f}" + ) + print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}") - # Show GRPO-specific metrics if available + # GRPO metrics line: KL, ratio, clipping + kl_penalty = metrics.get("kl_penalty", 0) + mean_ratio = metrics.get("mean_ratio", 1.0) + mean_kl = metrics.get("mean_kl", 0) + clipped_frac = metrics.get("clipped_fraction", 0) + + if kl_penalty > 0 or mean_kl > 0: + print( + f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, " + f"clipped={clipped_frac*100:.1f}%" + ) + + # Advantage distribution if "pos_count" in metrics or "neg_count" in metrics: pos_count = metrics.get("pos_count", 0) neg_count = metrics.get("neg_count", 0) @@ -463,24 +580,6 @@ def log_metrics( f" Advantages: +{int(pos_count)} / -{int(neg_count)}, " f"LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}" ) - - # Show logprob alignment stats (important for shared_vllm validation!) - if "logprob_alignment" in metrics: - alignment = metrics["logprob_alignment"] - if "logprobs/diff" in alignment: - diff = alignment["logprobs/diff"] - inf_mean = alignment.get("logprobs/inference_mean", 0) - train_mean = alignment.get("logprobs/training_mean", 0) - - # NOTE: This comparison has a fundamental timing issue! - # - inference_logprobs: from vLLM at generation time (possibly stale) - # - training_logprobs: from trainer's current forward pass - # After training starts, weights change, making comparison invalid. - # - # NOTE: This diff is just for monitoring, not validation! - # The real-time test at startup is the reliable alignment check. - # This diff will naturally drift as training progresses (expected). - print(f" LogProb Stats: inf_mean={inf_mean:.4f}, train_mean={train_mean:.4f}") if use_wandb: log_dict = { @@ -488,6 +587,11 @@ def log_metrics( "train/grad_norm": metrics["grad_norm"], "train/pos_logp": metrics.get("pos_logp", 0), "train/neg_logp": metrics.get("neg_logp", 0), + # GRPO-specific metrics + "grpo/kl_penalty": kl_penalty, + "grpo/mean_ratio": mean_ratio, + "grpo/mean_kl": mean_kl, + "grpo/clipped_fraction": clipped_frac, } # Add timing metrics if present for key in ["step_time", "sync_time", "data_fetch_time",