clean log

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 12:09:08 -04:00
parent d1b0dee8f7
commit 600c54f5f8
7 changed files with 15 additions and 206 deletions

View file

@@ -4,7 +4,7 @@ import time
import uuid
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, Request, status
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import PlainTextResponse
@@ -351,7 +351,7 @@ async def info():
@app.get("/batch")
async def get_batch(request: Request):
async def get_batch():
# Check if trainer has registered first
if not hasattr(app.state, "started"):
return {
@@ -363,27 +363,8 @@ async def get_batch(request: Request):
if not app.state.started:
app.state.started = True
client = request.client
client_addr = (
f"{client.host}:{client.port}" if client is not None else "unknown-client"
)
client_tag = request.headers.get("x-atropos-client", "unknown")
client_pid = request.headers.get("x-atropos-pid", "unknown")
if len(app.state.curr_batch) > 0:
curr_batch = app.state.curr_batch.pop()
logger.warning(
"API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
"groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
client_tag,
client_pid,
client_addr,
len(curr_batch),
sum(len(x["tokens"]) for x in curr_batch),
len(app.state.curr_batch),
len(app.state.queue),
)
return {"batch": curr_batch}
return {"batch": app.state.curr_batch.pop()}
else:
new_batches = []
# Check if any envs have minimum allocations
@@ -413,21 +394,6 @@ async def get_batch(request: Request):
)
steps_to_take = len(new_batches)
if steps_to_take == 0:
now = time.time()
last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
if now - last_empty_log > 30:
logger.warning(
"API /batch no full batch ready for client=%s pid=%s addr=%s: "
"queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
client_tag,
client_pid,
client_addr,
len(app.state.queue),
sum(len(x.get("tokens", [])) for x in app.state.queue),
len(app.state.curr_batch),
getattr(app.state, "batchsize", -1),
)
app.state._last_empty_batch_log = now
return {"batch": None}
app.state.status_dict["step"] += steps_to_take
# chunk it
@@ -435,18 +401,9 @@ async def get_batch(request: Request):
app.state.curr_batch.append(batch)
curr_batch = app.state.curr_batch.pop()
# check length before sending
logger.warning(
"API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
"with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
steps_to_take,
client_tag,
client_pid,
client_addr,
len(curr_batch),
logger.info(
"Sending batch of %s sequences",
sum(len(x["tokens"]) for x in curr_batch),
len(app.state.curr_batch),
len(app.state.queue),
app.state.status_dict["step"],
)
return {"batch": curr_batch}