mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
clean log
This commit is contained in:
parent
d1b0dee8f7
commit
600c54f5f8
7 changed files with 15 additions and 206 deletions
|
|
@ -377,61 +377,12 @@ def get_data(
|
|||
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
|
||||
"""
|
||||
batches = []
|
||||
_logged_logprob_warning = False
|
||||
empty_polls = 0
|
||||
|
||||
while True:
|
||||
data = get_batch(url=atropos_url)
|
||||
|
||||
if data["batch"] is not None:
|
||||
empty_polls = 0
|
||||
num_groups = len(data["batch"])
|
||||
num_sequences = sum(len(item["tokens"]) for item in data["batch"])
|
||||
max_seq_len = max(
|
||||
max(len(seq) for seq in item["tokens"]) for item in data["batch"]
|
||||
)
|
||||
print(
|
||||
" [Data] received API batch: "
|
||||
f"groups={num_groups} sequences={num_sequences} max_seq_len={max_seq_len}",
|
||||
flush=True,
|
||||
)
|
||||
# DEBUG: Check if inference_logprobs exists in the data
|
||||
if not _logged_logprob_warning:
|
||||
has_logprobs = any(
|
||||
"inference_logprobs" in item for item in data["batch"]
|
||||
)
|
||||
if has_logprobs:
|
||||
# Check if they're non-empty
|
||||
sample_item = next(
|
||||
(
|
||||
item
|
||||
for item in data["batch"]
|
||||
if "inference_logprobs" in item
|
||||
),
|
||||
None,
|
||||
)
|
||||
if sample_item and sample_item.get("inference_logprobs"):
|
||||
sample_lp = (
|
||||
sample_item["inference_logprobs"][0]
|
||||
if sample_item["inference_logprobs"]
|
||||
else []
|
||||
)
|
||||
print(
|
||||
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
" [Data] ⚠ inference_logprobs key exists but is empty!"
|
||||
)
|
||||
else:
|
||||
print(" [Data] ⚠ NO inference_logprobs in batch data!")
|
||||
print(
|
||||
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
|
||||
)
|
||||
_logged_logprob_warning = True
|
||||
|
||||
# Process and accumulate batches (now includes batched inference logprobs)
|
||||
print(" [Data] padding / batching API payload...", flush=True)
|
||||
(
|
||||
token_batches,
|
||||
label_batches,
|
||||
|
|
@ -441,12 +392,6 @@ def get_data(
|
|||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
batch_shapes = [tuple(tb.shape) for tb in token_batches]
|
||||
print(
|
||||
" [Data] pad_data_to_good_offset done: "
|
||||
f"micro_batches={len(token_batches)} token_batch_shapes={batch_shapes}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Include inference logprob batches in the tuple
|
||||
batches.append(
|
||||
|
|
@ -463,17 +408,7 @@ def get_data(
|
|||
|
||||
elif len(batches) > 0:
|
||||
# Return accumulated batches when no more data
|
||||
print(
|
||||
f" [Data] returning {len(batches)} assembled trainer batch tuple(s)",
|
||||
flush=True,
|
||||
)
|
||||
return batches, None
|
||||
else:
|
||||
# Wait for data
|
||||
empty_polls += 1
|
||||
if empty_polls == 1 or empty_polls % 30 == 0:
|
||||
print(
|
||||
f" [Data] no batch ready yet (polls_without_data={empty_polls})",
|
||||
flush=True,
|
||||
)
|
||||
time.sleep(1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue