mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
investigating weird training issue
This commit is contained in:
parent
a43b0b7e72
commit
690e670e64
2 changed files with 33 additions and 6 deletions
|
|
@ -4,7 +4,7 @@ import time
|
|||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
|
@ -351,7 +351,7 @@ async def info():
|
|||
|
||||
|
||||
@app.get("/batch")
|
||||
async def get_batch():
|
||||
async def get_batch(request: Request):
|
||||
# Check if trainer has registered first
|
||||
if not hasattr(app.state, "started"):
|
||||
return {
|
||||
|
|
@ -363,10 +363,21 @@ async def get_batch():
|
|||
if not app.state.started:
|
||||
app.state.started = True
|
||||
|
||||
client = request.client
|
||||
client_addr = (
|
||||
f"{client.host}:{client.port}" if client is not None else "unknown-client"
|
||||
)
|
||||
client_tag = request.headers.get("x-atropos-client", "unknown")
|
||||
client_pid = request.headers.get("x-atropos-pid", "unknown")
|
||||
|
||||
if len(app.state.curr_batch) > 0:
|
||||
curr_batch = app.state.curr_batch.pop()
|
||||
logger.warning(
|
||||
"API /batch returning prebuilt batch: groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
|
||||
"API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
|
||||
"groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
|
||||
client_tag,
|
||||
client_pid,
|
||||
client_addr,
|
||||
len(curr_batch),
|
||||
sum(len(x["tokens"]) for x in curr_batch),
|
||||
len(app.state.curr_batch),
|
||||
|
|
@ -406,7 +417,11 @@ async def get_batch():
|
|||
last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
|
||||
if now - last_empty_log > 30:
|
||||
logger.warning(
|
||||
"API /batch no full batch ready: queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
|
||||
"API /batch no full batch ready for client=%s pid=%s addr=%s: "
|
||||
"queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
|
||||
client_tag,
|
||||
client_pid,
|
||||
client_addr,
|
||||
len(app.state.queue),
|
||||
sum(len(x.get("tokens", [])) for x in app.state.queue),
|
||||
len(app.state.curr_batch),
|
||||
|
|
@ -421,8 +436,12 @@ async def get_batch():
|
|||
curr_batch = app.state.curr_batch.pop()
|
||||
# check length before sending
|
||||
logger.warning(
|
||||
"API /batch built %s trainer batch(es); returning one with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
|
||||
"API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
|
||||
"with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
|
||||
steps_to_take,
|
||||
client_tag,
|
||||
client_pid,
|
||||
client_addr,
|
||||
len(curr_batch),
|
||||
sum(len(x["tokens"]) for x in curr_batch),
|
||||
len(app.state.curr_batch),
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ Handles communication with the Atropos API server for:
|
|||
- Batch retrieval
|
||||
"""
|
||||
|
||||
import os
|
||||
import time as _time
|
||||
|
||||
import requests
|
||||
|
|
@ -99,7 +100,14 @@ def get_batch(url: str = "http://localhost:8000"):
|
|||
Raises:
|
||||
RuntimeError: If trainer is not registered or other API error
|
||||
"""
|
||||
data = requests.get(f"{url}/batch", timeout=10).json()
|
||||
data = requests.get(
|
||||
f"{url}/batch",
|
||||
headers={
|
||||
"X-Atropos-Client": "trainer",
|
||||
"X-Atropos-Pid": str(os.getpid()),
|
||||
},
|
||||
timeout=10,
|
||||
).json()
|
||||
|
||||
# Check if there was an error (trainer not registered)
|
||||
if data.get("status") == "error":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue