This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 16:06:19 -05:00
parent 2b5debe0a2
commit 851f0b6e17
2 changed files with 29 additions and 0 deletions

View file

@@ -237,11 +237,28 @@ def get_data(
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
"""
batches = []
_logged_logprob_warning = False
while True:
data = get_batch(url=atropos_url)
if data["batch"] is not None:
# DEBUG: Check if inference_logprobs exists in the data
if not _logged_logprob_warning:
has_logprobs = any("inference_logprobs" in item for item in data["batch"])
if has_logprobs:
# Check if they're non-empty
sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None)
if sample_item and sample_item.get("inference_logprobs"):
sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else []
print(f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})")
else:
print(f" [Data] ⚠ inference_logprobs key exists but is empty!")
else:
print(f" [Data] ⚠ NO inference_logprobs in batch data!")
print(f" [Data] Keys in first item: {list(data['batch'][0].keys())}")
_logged_logprob_warning = True
# Save batch for debugging
with open("temp.json", "w", encoding="utf-8") as f:
json.dump(data, f)