clean log

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 12:09:08 -04:00
parent d1b0dee8f7
commit 600c54f5f8
7 changed files with 15 additions and 206 deletions

View file

@@ -100,34 +100,15 @@ def get_batch(url: str = "http://localhost:8000"):
Raises:
RuntimeError: If trainer is not registered or other API error
"""
try:
response = requests.get(
f"{url}/batch",
headers={
"X-Atropos-Client": "trainer",
"X-Atropos-Pid": str(os.getpid()),
},
timeout=10,
)
print(
f" [Trainer/API] GET /batch status={response.status_code}",
flush=True,
)
data = response.json()
batch = data.get("batch")
if batch is None:
print(" [Trainer/API] parsed batch=None", flush=True)
else:
num_groups = len(batch)
num_sequences = sum(len(item["tokens"]) for item in batch)
print(
" [Trainer/API] parsed batch payload: "
f"groups={num_groups} sequences={num_sequences}",
flush=True,
)
except Exception as exc:
print(f" [Trainer/API] GET /batch failed: {exc!r}", flush=True)
raise
response = requests.get(
f"{url}/batch",
headers={
"X-Atropos-Client": "trainer",
"X-Atropos-Pid": str(os.getpid()),
},
timeout=10,
)
data = response.json()
# Check if there was an error (trainer not registered)
if data.get("status") == "error":

View file

@@ -377,61 +377,12 @@ def get_data(
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
"""
batches = []
_logged_logprob_warning = False
empty_polls = 0
while True:
data = get_batch(url=atropos_url)
if data["batch"] is not None:
empty_polls = 0
num_groups = len(data["batch"])
num_sequences = sum(len(item["tokens"]) for item in data["batch"])
max_seq_len = max(
max(len(seq) for seq in item["tokens"]) for item in data["batch"]
)
print(
" [Data] received API batch: "
f"groups={num_groups} sequences={num_sequences} max_seq_len={max_seq_len}",
flush=True,
)
# DEBUG: Check if inference_logprobs exists in the data
if not _logged_logprob_warning:
has_logprobs = any(
"inference_logprobs" in item for item in data["batch"]
)
if has_logprobs:
# Check if they're non-empty
sample_item = next(
(
item
for item in data["batch"]
if "inference_logprobs" in item
),
None,
)
if sample_item and sample_item.get("inference_logprobs"):
sample_lp = (
sample_item["inference_logprobs"][0]
if sample_item["inference_logprobs"]
else []
)
print(
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
)
else:
print(
" [Data] ⚠ inference_logprobs key exists but is empty!"
)
else:
print(" [Data] ⚠ NO inference_logprobs in batch data!")
print(
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
)
_logged_logprob_warning = True
# Process and accumulate batches (now includes batched inference logprobs)
print(" [Data] padding / batching API payload...", flush=True)
(
token_batches,
label_batches,
@@ -441,12 +392,6 @@
distill_token_id_batches,
distill_logprob_batches,
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
batch_shapes = [tuple(tb.shape) for tb in token_batches]
print(
" [Data] pad_data_to_good_offset done: "
f"micro_batches={len(token_batches)} token_batch_shapes={batch_shapes}",
flush=True,
)
# Include inference logprob batches in the tuple
batches.append(
@@ -463,17 +408,7 @@
elif len(batches) > 0:
# Return accumulated batches when no more data
print(
f" [Data] returning {len(batches)} assembled trainer batch tuple(s)",
flush=True,
)
return batches, None
else:
# Wait for data
empty_polls += 1
if empty_polls == 1 or empty_polls % 30 == 0:
print(
f" [Data] no batch ready yet (polls_without_data={empty_polls})",
flush=True,
)
time.sleep(1)