diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py
index c005a924..e072829b 100644
--- a/example_trainer/trainers.py
+++ b/example_trainer/trainers.py
@@ -203,12 +203,14 @@ def train_legacy(config: TrainingConfig):
     for step in range(config.training_steps):
         print(f"\nStep {step+1}/{config.training_steps}")
 
-        # Fetch data
+        # Fetch data (with inference logprobs for proper GRPO)
         data_fetch_start = time.time()
         if len(batches) == 0:
             batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
-                                  extract_inference_logprobs=False)
-        token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
+                                  extract_inference_logprobs=True)
+        batch_data = batches.pop(0)
+        token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
+        inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)
 
@@ -217,12 +219,13 @@ def train_legacy(config: TrainingConfig):
         if should_sync:
             terminate_vllm_process()
 
-        # Training step
+        # Training step (with proper GRPO using inference logprobs)
         step_start = time.time()
         metrics = run_training_step(
             model, optimizer, token_batches, label_batches,
             advantage_batches, temperature_batches, config,
+            inference_logprob_batches=inference_logprob_batches,
         )
         step_time = time.time() - step_start
         benchmark_stats["step_times"].append(step_time)
 
@@ -518,34 +521,32 @@ def train_shared_vllm(config: TrainingConfig):
     # === Training Loop ===
     batches = []
-    inference_logprobs = None
     for step in range(config.training_steps):
         print(f"\nStep {step+1}/{config.training_steps}")
 
-        # Fetch data (with inference logprobs for alignment check)
+        # Fetch data (with inference logprobs for proper GRPO loss)
         data_fetch_start = time.time()
         if len(batches) == 0:
-            batches, inference_logprobs = get_data(
+            batches, _ = get_data(
                 config.batch_size, config.seq_len, config.atropos_url,
-                extract_inference_logprobs=True,  # Enable logprob alignment check
+                extract_inference_logprobs=True,  # Enable proper GRPO with reference logprobs
             )
-        token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
+        batch_data = batches.pop(0)
+        token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
+        inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)
 
-        # Training step (with logprob alignment check)
+        # Training step with proper GRPO (importance sampling + KL penalty)
         step_start = time.time()
         metrics = run_training_step(
             model, optimizer, token_batches, label_batches,
             advantage_batches, temperature_batches, config,
-            inference_logprobs=inference_logprobs,  # Pass for alignment validation
+            inference_logprob_batches=inference_logprob_batches,  # Pass for GRPO ratio computation
         )
         step_time = time.time() - step_start
         benchmark_stats["step_times"].append(step_time)
-
-        # Clear inference logprobs after use (will be refreshed with new data)
-        inference_logprobs = None
 
         # GPU memory tracking
         gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
@@ -652,21 +653,24 @@ def train_lora(config: TrainingConfig):
     for step in range(config.training_steps):
         print(f"\nStep {step+1}/{config.training_steps}")
 
-        # Fetch data
+        # Fetch data (with inference logprobs for proper GRPO)
         data_fetch_start = time.time()
         if len(batches) == 0:
             batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
-                                  extract_inference_logprobs=False)
-        token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
+                                  extract_inference_logprobs=True)
+        batch_data = batches.pop(0)
+        token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
+        inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)
 
-        # Training step
+        # Training step with proper GRPO
         step_start = time.time()
         metrics = run_training_step(
             model, optimizer, token_batches, label_batches,
             advantage_batches, temperature_batches, config,
+            inference_logprob_batches=inference_logprob_batches,
        )
         step_time = time.time() - step_start
         benchmark_stats["step_times"].append(step_time)
 
diff --git a/example_trainer/training.py b/example_trainer/training.py
index 5b689cc8..69ff1d9f 100644
--- a/example_trainer/training.py
+++ b/example_trainer/training.py
@@ -238,13 +238,20 @@ def compute_grpo_loss(
     policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
 
     # KL penalty: encourage staying close to reference policy
-    # KL(π || π_ref) ≈ log(π/π_ref) = log_ratio (when π_ref is the reference)
-    # We use the approximation: KL ≈ (ratio - 1) - log(ratio)
-    # But simpler: just penalize squared log-ratio which is symmetric
+    # Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4):
+    #   D_KL(π_θ || π_ref) = (π_ref / π_θ) - log(π_ref / π_θ) - 1
+    #
+    # In terms of log probabilities:
+    #   log_ratio = log π_θ - log π_ref   (what we computed above)
+    #   ratio_ref_over_pi = exp(-log_ratio) = π_ref / π_θ
+    #   kl = ratio_ref_over_pi - log(ratio_ref_over_pi) - 1
+    #      = exp(-log_ratio) + log_ratio - 1
+    #
+    # This estimator is unbiased and non-negative; the squared log-ratio is only a biased second-order approximation.
     if kl_coef > 0:
-        # Approximate KL using (log_ratio)^2 / 2 (Taylor expansion)
-        # Or just use log_ratio directly as a penalty
-        kl_per_token = log_ratio.pow(2)  # Squared for symmetric penalty
+        # Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1
+        # = exp(-log_ratio) + log_ratio - 1
+        kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
         kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
         total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
     else:
@@ -257,9 +264,11 @@ def compute_grpo_loss(
         clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
         clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
 
-        # Mean ratio and KL for monitoring
+        # Mean ratio and KL for monitoring (using Schulman's estimator)
         mean_ratio = (ratio * mask).sum() / mask.sum()
-        mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum()
+        # Schulman KL: exp(-log_ratio) + log_ratio - 1
+        schulman_kl = torch.exp(-log_ratio) + log_ratio - 1.0
+        mean_kl = (schulman_kl * mask).sum() / mask.sum()
 
     # For backward compatibility: collect training logprobs
     raw_logp_per_token = -F.cross_entropy(
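For context on the new `inference_logprob_batches` plumbing: the loss needs the sampler's per-token logprobs to form the importance ratio π_θ/π_ref. Below is a minimal, self-contained sketch of the loss shape this diff moves to (clipped surrogate plus the k3 KL penalty). It is an illustration only, not the repo's `compute_grpo_loss`; the tensor names, shapes, and default hyperparameters here are assumptions.

```python
import torch


def grpo_loss_sketch(
    train_logprobs: torch.Tensor,      # log pi_theta per token, shape [B, T] (illustrative name)
    inference_logprobs: torch.Tensor,  # log pi_ref per token from the sampler, shape [B, T]
    advantages: torch.Tensor,          # advantages, shape [B, 1] (broadcast over tokens)
    mask: torch.Tensor,                # 1.0 on trainable tokens, 0.0 elsewhere, shape [B, T]
    clip_eps: float = 0.2,             # illustrative default
    kl_coef: float = 0.04,             # illustrative default
) -> torch.Tensor:
    # Importance ratio between the training policy and the inference policy.
    log_ratio = train_logprobs - inference_logprobs      # log(pi_theta / pi_ref)
    ratio = torch.exp(log_ratio)

    # PPO-style clipped surrogate (negated, since we minimize a loss).
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss_per_token = -torch.minimum(unclipped, clipped)

    # Schulman's k3 KL estimator, as in the diff: exp(-log_ratio) + log_ratio - 1.
    kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0

    # Masked per-sequence means, then batch mean.
    mask_sum = mask.sum(dim=-1).clamp(min=1.0)
    policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
    kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
    return policy_loss + kl_coef * kl_penalty
```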
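The estimator swap itself can be sanity-checked in isolation: k3 = exp(-log_ratio) + log_ratio - 1 is pointwise non-negative and unbiased for D_KL(π_θ || π_ref) under samples from π_θ, whereas the old `log_ratio.pow(2)` is only proportional to a biased second-order approximation. A quick synthetic check (not part of the diff; distributions and sizes are made up):

```python
import torch

torch.manual_seed(0)

# Two nearby categorical policies over a 1000-token vocabulary.
logits_p = torch.randn(1000)                      # training policy pi_theta
logits_q = logits_p + 0.1 * torch.randn(1000)     # reference policy pi_ref
log_p = torch.log_softmax(logits_p, dim=-1)
log_q = torch.log_softmax(logits_q, dim=-1)

# Exact KL(pi_theta || pi_ref) = sum_x p(x) * (log p(x) - log q(x)).
exact_kl = (log_p.exp() * (log_p - log_q)).sum()

# Monte Carlo estimates from samples x ~ pi_theta.
samples = torch.multinomial(log_p.exp(), 100_000, replacement=True)
log_ratio = log_p[samples] - log_q[samples]       # log(pi_theta / pi_ref)

k3 = torch.exp(-log_ratio) + log_ratio - 1.0      # new estimator: unbiased, >= 0
k2 = 0.5 * log_ratio.pow(2)                       # old penalty / 2: biased approximation

print(f"exact KL    : {exact_kl.item():.6f}")
print(f"k3 estimate : {k3.mean().item():.6f}  (per-sample min: {k3.min().item():.2e})")
print(f"k2 estimate : {k2.mean().item():.6f}")
```

The k3 mean should track the exact KL closely, and its per-sample minimum should sit at or just above zero (up to floating-point rounding).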