Mirror of https://github.com/NousResearch/atropos.git, synced 2026-05-01 17:45:16 +00:00

Commit 851f0b6e17 ("debug"), parent 2b5debe0a2.
2 changed files with 29 additions and 0 deletions.
@@ -237,11 +237,28 @@ def get_data(
     - inference_logprob_batches are aligned with labels for proper GRPO loss computation
     """
     batches = []
+    _logged_logprob_warning = False

     while True:
         data = get_batch(url=atropos_url)

         if data["batch"] is not None:
+            # DEBUG: Check if inference_logprobs exists in the data
+            if not _logged_logprob_warning:
+                has_logprobs = any("inference_logprobs" in item for item in data["batch"])
+                if has_logprobs:
+                    # Check if they're non-empty
+                    sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None)
+                    if sample_item and sample_item.get("inference_logprobs"):
+                        sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else []
+                        print(f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})")
+                    else:
+                        print(f" [Data] ⚠ inference_logprobs key exists but is empty!")
+                else:
+                    print(f" [Data] ⚠ NO inference_logprobs in batch data!")
+                    print(f" [Data] Keys in first item: {list(data['batch'][0].keys())}")
+                _logged_logprob_warning = True
+
             # Save batch for debugging
             with open("temp.json", "w", encoding="utf-8") as f:
                 json.dump(data, f)
|||
|
|
@@ -214,6 +214,18 @@ def compute_grpo_loss(
     # Move inference logprobs to correct device/dtype
     ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)

+    # DEBUG: Check if inference logprobs look valid
+    with torch.no_grad():
+        ref_nonzero = (ref_logprobs != 0).float()
+        ref_nonzero_frac = (ref_nonzero * mask).sum() / mask.sum()
+        ref_mean = (ref_logprobs * mask).sum() / mask.sum()
+        train_mean = (logp_per_token * mask).sum() / mask.sum()
+        if ref_nonzero_frac < 0.5:
+            print(f" [WARNING] Only {ref_nonzero_frac*100:.1f}% of inference logprobs are non-zero!")
+            print(f" [WARNING] This suggests inference_logprobs field may be missing from data")
+        if abs(ref_mean - train_mean) > 1.0:
+            print(f" [DEBUG] Large logprob gap: ref_mean={ref_mean:.3f}, train_mean={train_mean:.3f}")
+
     # Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
     log_ratio = logp_per_token - ref_logprobs
     ratio = torch.exp(log_ratio)
Loading…
Add table
Add a link
Reference in a new issue