diff --git a/example_trainer/data.py b/example_trainer/data.py
index be7667c2..e79fecb0 100644
--- a/example_trainer/data.py
+++ b/example_trainer/data.py
@@ -237,11 +237,28 @@ def get_data(
     - inference_logprob_batches are aligned with labels for proper GRPO loss computation
     """
     batches = []
+    _logged_logprob_warning = False
     while True:
         data = get_batch(url=atropos_url)
         if data["batch"] is not None:
+            # DEBUG: Check if inference_logprobs exists in the data
+            if not _logged_logprob_warning:
+                has_logprobs = any("inference_logprobs" in item for item in data["batch"])
+                if has_logprobs:
+                    # Check if they're non-empty
+                    sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None)
+                    if sample_item and sample_item.get("inference_logprobs"):
+                        sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else []
+                        print(f"  [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})")
+                    else:
+                        print("  [Data] ⚠ inference_logprobs key exists but is empty!")
+                else:
+                    print("  [Data] ⚠ NO inference_logprobs in batch data!")
+                    print(f"  [Data] Keys in first item: {list(data['batch'][0].keys())}")
+                _logged_logprob_warning = True
+
             # Save batch for debugging
             with open("temp.json", "w", encoding="utf-8") as f:
                 json.dump(data, f)
diff --git a/example_trainer/training.py b/example_trainer/training.py
index 69ff1d9f..f36e9b63 100644
--- a/example_trainer/training.py
+++ b/example_trainer/training.py
@@ -214,6 +214,18 @@ def compute_grpo_loss(
     # Move inference logprobs to correct device/dtype
     ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
 
+    # DEBUG: Check if inference logprobs look valid
+    with torch.no_grad():
+        ref_nonzero = (ref_logprobs != 0).float()
+        ref_nonzero_frac = (ref_nonzero * mask).sum() / mask.sum()
+        ref_mean = (ref_logprobs * mask).sum() / mask.sum()
+        train_mean = (logp_per_token * mask).sum() / mask.sum()
+        if ref_nonzero_frac < 0.5:
+            print(f"  [WARNING] Only {ref_nonzero_frac*100:.1f}% of inference logprobs are non-zero!")
+            print("  [WARNING] This suggests the inference_logprobs field may be missing from the data")
+        if abs(ref_mean - train_mean) > 1.0:
+            print(f"  [DEBUG] Large logprob gap: ref_mean={ref_mean:.3f}, train_mean={train_mean:.3f}")
+
     # Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
     log_ratio = logp_per_token - ref_logprobs
     ratio = torch.exp(log_ratio)
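
Review note: the data.py hunk reduces to a three-way classification of the incoming batch (key missing everywhere, key present but empty, key present with sequences). A minimal standalone sketch of that logic, assuming `batch` is the list of dicts returned by `get_batch()`; the helper name `check_inference_logprobs` is hypothetical, not part of the patch:

def check_inference_logprobs(batch):
    """Classify a batch by the state of its inference_logprobs field."""
    if not any("inference_logprobs" in item for item in batch):
        return "missing"  # key absent from every item
    sample = next(item for item in batch if "inference_logprobs" in item)
    if not sample.get("inference_logprobs"):
        return "empty"  # key present but holds no logprob sequences
    return "present"

# Toy batches exercising the three branches the patch prints messages for.
assert check_inference_logprobs([{"tokens": [1, 2]}]) == "missing"
assert check_inference_logprobs([{"inference_logprobs": []}]) == "empty"
assert check_inference_logprobs([{"inference_logprobs": [[-0.1, -2.3]]}]) == "present"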
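
The training.py diagnostics are masked means over the loss tokens. A minimal sketch of the same arithmetic, assuming `mask` is a 0/1 tensor with the same shape as the logprob tensors; `logprob_diagnostics` is a hypothetical helper, not the trainer's API:

import torch

def logprob_diagnostics(ref_logprobs, logp_per_token, mask):
    """Masked summary stats mirroring the checks added in compute_grpo_loss."""
    with torch.no_grad():
        denom = mask.sum().clamp_min(1.0)  # guard against an all-masked batch
        nonzero_frac = ((ref_logprobs != 0).float() * mask).sum() / denom
        ref_mean = (ref_logprobs * mask).sum() / denom
        train_mean = (logp_per_token * mask).sum() / denom
    return {
        "ref_nonzero_frac": nonzero_frac.item(),
        "gap": (ref_mean - train_mean).abs().item(),
    }

# Toy check: all-zero inference logprobs should trip the 50% threshold.
stats = logprob_diagnostics(torch.zeros(2, 4), -torch.rand(2, 4), torch.ones(2, 4))
assert stats["ref_nonzero_frac"] < 0.5

The non-zero check matters because of the ratio computed right after the hunk: if ref_logprobs is silently all zeros, log_ratio collapses to logp_per_token and ratio = exp(logp_per_token), which sits well below 1 for any real token logprob instead of near 1, so the importance weights (and any clipping built on them) become meaningless.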