diff --git a/example_trainer/data.py b/example_trainer/data.py index 74e1ffe1..9f79ff08 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -123,7 +123,7 @@ def pad_data_to_good_offset( # IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks: # - 1.0 for prompt tokens (masked positions) # - actual negative logprobs for generated tokens - # We just need to pad to match the sequence length, no realignment needed! + # We just need to pad to match the sequence length if extract_inference_logprobs and "inference_logprobs" in item: if i < len(item["inference_logprobs"]): raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32) @@ -140,10 +140,10 @@ def pad_data_to_good_offset( # Shift by 1 to match causal label shift inference_logprobs_padded.append(padded_logprobs[1:]) else: - # No logprobs for this sample, use 1.0 (masked placeholder) + # No logprobs for this sample, use 1.0 inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32)) elif extract_inference_logprobs: - # No inference_logprobs in item, use 1.0 (masked placeholder) + # No inference_logprobs in item, use 1.0 inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32)) # Extract temperature (priority: override > generation_params > group_overrides > 1.0) diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index 0b1f4604..a90de4b4 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -1,19 +1,11 @@ #!/usr/bin/env python3 """ -GRPO Trainer - Multi-Mode Entry Point Supports three training modes: - none (legacy): Periodic checkpoint saves + vLLM restarts - shared_vllm: Single-copy mode with CUDA IPC weight sharing - lora_only: LoRA adapter training -For the unified single-command experience (shared_vllm + auto vLLM launch), -use run.py instead: - python example_trainer/run.py --model Qwen/Qwen3-4B --training-steps 20 - -This script requires vLLM to be running separately (except for legacy 
mode -which manages vLLM internally). - Usage: # Legacy mode (manages vLLM internally) python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct diff --git a/example_trainer/training.py b/example_trainer/training.py index 39853fad..5ab7f2f6 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -23,89 +23,6 @@ from .config import TrainingConfig # Global storage for logprob alignment stats _logprob_alignment_stats: Dict[str, float] = {} -# Global storage for weight verification -_weight_snapshot: Dict[str, float] = {} - - -def verify_vllm_sees_updates(model: torch.nn.Module, vllm_port: int, step: int) -> bool: - """ - Verify that vLLM actually sees weight updates by corrupting a weight - and checking if vLLM's output changes. - - Returns True if vLLM sees updates, False otherwise. - """ - import requests - - try: - # Find embedding layer - embed_param = None - for name, param in model.named_parameters(): - if "embed_tokens" in name: - embed_param = param - break - - if embed_param is None: - return True # Can't verify, assume OK - - test_prompt = "Hello" - vllm_url = f"http://localhost:{vllm_port}" - - # Get baseline - r1 = requests.post( - f"{vllm_url}/generate", - json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0}, - timeout=10, - ) - baseline = r1.json().get("text", [""])[0] if r1.status_code == 200 else None - - if baseline is None: - return True # Can't verify - - # Corrupt weight - original = embed_param.data[0, 0].clone() - embed_param.data[0, 0] = 9999.0 - - # Query vLLM - r2 = requests.post( - f"{vllm_url}/generate", - json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0}, - timeout=10, - ) - corrupted = r2.json().get("text", [""])[0] if r2.status_code == 200 else baseline - - # Restore - embed_param.data[0, 0] = original - - # Check if output changed - sharing_works = (corrupted != baseline) - - if not sharing_works and step > 0: - print(f" [WARN] Step {step}: vLLM may not see weight updates!") - 
- return sharing_works - - except Exception: - return True # Can't verify, assume OK - - -def snapshot_weights(model: torch.nn.Module) -> Dict[str, float]: - """Take a snapshot of sample weight values for comparison.""" - snapshot = {} - for name, param in model.named_parameters(): - if any(x in name for x in ["layers.0.", "layers.10.", "embed_tokens", "lm_head"]): - snapshot[name] = param.data.flatten()[0].item() - return snapshot - - -def compare_weight_snapshots(old: Dict[str, float], new: Dict[str, float]) -> Dict[str, float]: - """Compare two weight snapshots and return differences.""" - diffs = {} - for name in old: - if name in new: - diffs[name] = abs(new[name] - old[name]) - return diffs - - def setup_wandb(config: TrainingConfig) -> bool: """ Initialize Weights & Biases logging if enabled. @@ -161,7 +78,7 @@ def compute_grpo_loss( Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch. This implements proper GRPO/PPO with: - - Importance sampling ratio: π(a|s) / π_old(a|s) + - Importance sampling ratio: policy(a|s) / policy_old(a|s) - PPO-style clipping to prevent large updates - KL penalty to prevent reward hacking/policy collapse @@ -189,7 +106,7 @@ def compute_grpo_loss( outputs = model(tokens) logits = outputs.logits - # Temperature scaling for training + # Temperature scaling for training; otherwise the importance ratio is likely off t = temperatures.to(logits.device, logits.dtype) t = torch.where(t <= 0, torch.ones_like(t), t) scaled_logits = logits / t diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index 2d5a8688..cae00156 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -47,10 +47,6 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, List, Optional -# ============================================================================= -# CRITICAL: Set up multiprocessing and vLLM engine BEFORE any CUDA imports -# 
============================================================================= - # Default to v0 engine to avoid CUDA fork issues with v1 engine # Users can override with VLLM_USE_V1=1 if needed os.environ.setdefault("VLLM_USE_V1", "0") @@ -168,11 +164,6 @@ except ImportError: logger = init_logger("vllm.entrypoints.api_server") - -# ============================================================================= -# Global State -# ============================================================================= - app = FastAPI() engine: Optional[AsyncLLM] = None