investigating weird training issue

This commit is contained in:
Jai Suphavadeeprasit 2026-03-12 16:11:06 -04:00
parent a43b0b7e72
commit 690e670e64
2 changed files with 33 additions and 6 deletions

View file

@ -4,7 +4,7 @@ import time
import uuid
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, status
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import PlainTextResponse
@ -351,7 +351,7 @@ async def info():
@app.get("/batch")
async def get_batch():
async def get_batch(request: Request):
# Check if trainer has registered first
if not hasattr(app.state, "started"):
return {
@ -363,10 +363,21 @@ async def get_batch():
if not app.state.started:
app.state.started = True
client = request.client
client_addr = (
f"{client.host}:{client.port}" if client is not None else "unknown-client"
)
client_tag = request.headers.get("x-atropos-client", "unknown")
client_pid = request.headers.get("x-atropos-pid", "unknown")
if len(app.state.curr_batch) > 0:
curr_batch = app.state.curr_batch.pop()
logger.warning(
"API /batch returning prebuilt batch: groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
"API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
"groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
client_tag,
client_pid,
client_addr,
len(curr_batch),
sum(len(x["tokens"]) for x in curr_batch),
len(app.state.curr_batch),
@ -406,7 +417,11 @@ async def get_batch():
last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
if now - last_empty_log > 30:
logger.warning(
"API /batch no full batch ready: queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
"API /batch no full batch ready for client=%s pid=%s addr=%s: "
"queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
client_tag,
client_pid,
client_addr,
len(app.state.queue),
sum(len(x.get("tokens", [])) for x in app.state.queue),
len(app.state.curr_batch),
@ -421,8 +436,12 @@ async def get_batch():
curr_batch = app.state.curr_batch.pop()
# check length before sending
logger.warning(
"API /batch built %s trainer batch(es); returning one with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
"API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
"with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
steps_to_take,
client_tag,
client_pid,
client_addr,
len(curr_batch),
sum(len(x["tokens"]) for x in curr_batch),
len(app.state.curr_batch),

View file

@ -7,6 +7,7 @@ Handles communication with the Atropos API server for:
- Batch retrieval
"""
import os
import time as _time
import requests
@ -99,7 +100,14 @@ def get_batch(url: str = "http://localhost:8000"):
Raises:
RuntimeError: If trainer is not registered or other API error
"""
data = requests.get(f"{url}/batch", timeout=10).json()
data = requests.get(
f"{url}/batch",
headers={
"X-Atropos-Client": "trainer",
"X-Atropos-Pid": str(os.getpid()),
},
timeout=10,
).json()
# Check if there was an error (trainer not registered)
if data.get("status") == "error":