training kernel

This commit is contained in:
Jai Suphavadeeprasit 2026-03-12 14:51:28 -04:00
parent 7ec622a098
commit a43b0b7e72
3 changed files with 67 additions and 3 deletions

View file

@@ -364,7 +364,15 @@ async def get_batch():
app.state.started = True
if len(app.state.curr_batch) > 0:
return {"batch": app.state.curr_batch.pop()}
curr_batch = app.state.curr_batch.pop()
logger.warning(
"API /batch returning prebuilt batch: groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
len(curr_batch),
sum(len(x["tokens"]) for x in curr_batch),
len(app.state.curr_batch),
len(app.state.queue),
)
return {"batch": curr_batch}
else:
new_batches = []
# Check if any envs have minimum allocations
@@ -394,6 +402,17 @@ async def get_batch():
)
steps_to_take = len(new_batches)
if steps_to_take == 0:
now = time.time()
last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
if now - last_empty_log > 30:
logger.warning(
"API /batch no full batch ready: queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
len(app.state.queue),
sum(len(x.get("tokens", [])) for x in app.state.queue),
len(app.state.curr_batch),
getattr(app.state, "batchsize", -1),
)
app.state._last_empty_batch_log = now
return {"batch": None}
app.state.status_dict["step"] += steps_to_take
# chunk it
@@ -401,9 +420,14 @@ async def get_batch():
app.state.curr_batch.append(batch)
curr_batch = app.state.curr_batch.pop()
# check length before sending
logger.info(
"Sending batch of %s sequences",
logger.warning(
"API /batch built %s trainer batch(es); returning one with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
steps_to_take,
len(curr_batch),
sum(len(x["tokens"]) for x in curr_batch),
len(app.state.curr_batch),
len(app.state.queue),
app.state.status_dict["step"],
)
return {"batch": curr_batch}