mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
training kernel
This commit is contained in:
parent
7ec622a098
commit
a43b0b7e72
3 changed files with 67 additions and 3 deletions
|
|
@@ -364,7 +364,15 @@ async def get_batch():
|
|||
app.state.started = True
|
||||
|
||||
if len(app.state.curr_batch) > 0:
|
||||
return {"batch": app.state.curr_batch.pop()}
|
||||
curr_batch = app.state.curr_batch.pop()
|
||||
logger.warning(
|
||||
"API /batch returning prebuilt batch: groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
|
||||
len(curr_batch),
|
||||
sum(len(x["tokens"]) for x in curr_batch),
|
||||
len(app.state.curr_batch),
|
||||
len(app.state.queue),
|
||||
)
|
||||
return {"batch": curr_batch}
|
||||
else:
|
||||
new_batches = []
|
||||
# Check if any envs have minimum allocations
|
||||
|
|
@@ -394,6 +402,17 @@ async def get_batch():
|
|||
)
|
||||
steps_to_take = len(new_batches)
|
||||
if steps_to_take == 0:
|
||||
now = time.time()
|
||||
last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
|
||||
if now - last_empty_log > 30:
|
||||
logger.warning(
|
||||
"API /batch no full batch ready: queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
|
||||
len(app.state.queue),
|
||||
sum(len(x.get("tokens", [])) for x in app.state.queue),
|
||||
len(app.state.curr_batch),
|
||||
getattr(app.state, "batchsize", -1),
|
||||
)
|
||||
app.state._last_empty_batch_log = now
|
||||
return {"batch": None}
|
||||
app.state.status_dict["step"] += steps_to_take
|
||||
# chunk it
|
||||
|
|
@@ -401,9 +420,14 @@ async def get_batch():
|
|||
app.state.curr_batch.append(batch)
|
||||
curr_batch = app.state.curr_batch.pop()
|
||||
# check length before sending
|
||||
logger.info(
|
||||
"Sending batch of %s sequences",
|
||||
logger.warning(
|
||||
"API /batch built %s trainer batch(es); returning one with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
|
||||
steps_to_take,
|
||||
len(curr_batch),
|
||||
sum(len(x["tokens"]) for x in curr_batch),
|
||||
len(app.state.curr_batch),
|
||||
len(app.state.queue),
|
||||
app.state.status_dict["step"],
|
||||
)
|
||||
return {"batch": curr_batch}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue