This commit is contained in:
Jai Suphavadeeprasit 2026-02-03 12:18:54 -05:00
parent 2eaff54351
commit e932369777
7 changed files with 32 additions and 38 deletions

View file

@@ -140,10 +140,10 @@ def compute_grpo_loss(
# Check if ref logprobs are negative (as they should be for generated tokens)
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
if ref_at_generated > 0.5:
-print(f" [WARNING] ref_logprobs at generated positions avg {ref_at_generated:.3f} (should be negative!)")
-print(f" [WARNING] This suggests inference_logprobs alignment is still wrong")
+print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)")
+print(" [WARNING] This suggests inference_logprobs alignment is wrong")
elif abs(ref_at_generated - train_at_generated) > 2.0:
-print(f" [DEBUG] Logprob gap (may be OK for first step): ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
+print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
log_ratio = logp_per_token - ref_logprobs
@@ -277,8 +277,6 @@ def run_training_step(
Returns:
Dict of training metrics for this step
"""
-global _logprob_alignment_stats
total_loss = 0.0
total_pos_logp = 0.0
total_neg_logp = 0.0
@@ -399,8 +397,6 @@ def log_metrics(
extra_metrics: Optional additional metrics to log
benchmark: Whether to show timing/benchmark info
"""
-global _logprob_alignment_stats
# Build timing string (only if benchmark enabled)
timing_str = ""
if benchmark:
@@ -462,7 +458,7 @@ def log_metrics(
if key in metrics:
log_dict[f"train/{key}"] = metrics[key]
-# Add logprob alignment stats (key for shared_vllm validation!)
+# Add logprob alignment stats
if _logprob_alignment_stats:
log_dict.update(_logprob_alignment_stats)