clean log

Jai Suphavadeeprasit 2026-03-13 12:09:08 -04:00
parent d1b0dee8f7
commit 600c54f5f8
7 changed files with 15 additions and 206 deletions

View file

@@ -4,7 +4,7 @@ import time
 import uuid
 from typing import Any, Dict, List, Optional
-from fastapi import FastAPI, Request, status
+from fastapi import FastAPI, status
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.responses import PlainTextResponse
@@ -351,7 +351,7 @@ async def info():
 @app.get("/batch")
-async def get_batch(request: Request):
+async def get_batch():
     # Check if trainer has registered first
     if not hasattr(app.state, "started"):
         return {
@@ -363,27 +363,8 @@ async def get_batch(request: Request):
     if not app.state.started:
         app.state.started = True
-    client = request.client
-    client_addr = (
-        f"{client.host}:{client.port}" if client is not None else "unknown-client"
-    )
-    client_tag = request.headers.get("x-atropos-client", "unknown")
-    client_pid = request.headers.get("x-atropos-pid", "unknown")
     if len(app.state.curr_batch) > 0:
-        curr_batch = app.state.curr_batch.pop()
-        logger.warning(
-            "API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
-            "groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
-            client_tag,
-            client_pid,
-            client_addr,
-            len(curr_batch),
-            sum(len(x["tokens"]) for x in curr_batch),
-            len(app.state.curr_batch),
-            len(app.state.queue),
-        )
-        return {"batch": curr_batch}
+        return {"batch": app.state.curr_batch.pop()}
     else:
         new_batches = []
         # Check if any envs have minimum allocations
@@ -413,21 +394,6 @@ async def get_batch(request: Request):
             )
         steps_to_take = len(new_batches)
         if steps_to_take == 0:
-            now = time.time()
-            last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
-            if now - last_empty_log > 30:
-                logger.warning(
-                    "API /batch no full batch ready for client=%s pid=%s addr=%s: "
-                    "queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
-                    client_tag,
-                    client_pid,
-                    client_addr,
-                    len(app.state.queue),
-                    sum(len(x.get("tokens", [])) for x in app.state.queue),
-                    len(app.state.curr_batch),
-                    getattr(app.state, "batchsize", -1),
-                )
-                app.state._last_empty_batch_log = now
             return {"batch": None}
         app.state.status_dict["step"] += steps_to_take
         # chunk it
@@ -435,18 +401,9 @@ async def get_batch(request: Request):
             app.state.curr_batch.append(batch)
         curr_batch = app.state.curr_batch.pop()
         # check length before sending
-        logger.warning(
-            "API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
-            "with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
-            steps_to_take,
-            client_tag,
-            client_pid,
-            client_addr,
-            len(curr_batch),
+        logger.info(
+            "Sending batch of %s sequences",
             sum(len(x["tokens"]) for x in curr_batch),
-            len(app.state.curr_batch),
-            len(app.state.queue),
-            app.state.status_dict["step"],
         )
         return {"batch": curr_batch}

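Taken together, the three hunks above reduce the /batch handler to a much smaller surface. A minimal runnable sketch of the resulting control flow, assuming only the app.state fields visible in the diff (started, curr_batch, queue) and eliding the batch-assembly slow path:

from fastapi import FastAPI

app = FastAPI()

@app.get("/batch")
async def get_batch():
    # Before the trainer registers, app.state.started does not exist yet.
    if not hasattr(app.state, "started"):
        return {"batch": None}  # placeholder; the real handler returns a status payload
    if not app.state.started:
        app.state.started = True
    # Fast path: pop and return a prebuilt batch, with no per-request bookkeeping.
    if len(app.state.curr_batch) > 0:
        return {"batch": app.state.curr_batch.pop()}
    # Slow path (elided here): assemble new batches from app.state.queue and
    # return {"batch": None} when no full batch is ready.
    return {"batch": None}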
View file

@@ -907,7 +907,7 @@ class BaseEnv(ABC):
                 "ensure your trainer handles this appropriately."
             )
         elif abort_on_any_max_length_exceeded and any(
-            [len(x) > self.max_token_len for x in group["tokens"]]
+            [len(x) >= self.max_token_len for x in group["tokens"]]
         ):
             logger.warning("Token length is too long in a group, skipping...")
             continue

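The BaseEnv change above is a one-character semantic fix rather than log cleanup: a group is now skipped when any sequence reaches max_token_len, not only when it exceeds it. A self-contained illustration of the boundary case (the values are hypothetical):

max_token_len = 4
group = {"tokens": [[1, 2, 3], [4, 5, 6, 7]]}  # second sequence is exactly 4 tokens

old_check = any(len(x) > max_token_len for x in group["tokens"])   # False -> group kept
new_check = any(len(x) >= max_token_len for x in group["tokens"])  # True  -> group skipped
print(old_check, new_check)  # False True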
View file

@@ -447,33 +447,14 @@ class ManagedServer:
         if not self.track_tree and self.tokenizer is not None:
             input_ids = self._compute_input_ids(prompt, extending_node)
             completion_kwargs["input_ids"] = input_ids
-            logger.warning(
-                "managed_server chat_completion prepared input_ids=%s extending=%s",
-                len(input_ids),
-                extending_node is not None,
-            )
-        else:
-            logger.warning(
-                "managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s",
-                self.track_tree,
-                self.tokenizer is not None,
-            )
         # Call the tokens and logprobs wrapper directly
-        logger.warning(
-            "managed_server chat_completion calling backend completion wrapper"
-        )
         (
             prompt_tokens,
             output_tokens_list,
             output_logprobs_list,
             finish_reasons,
         ) = await self.server.tokens_and_logprobs_completion(**completion_kwargs)
-        logger.warning(
-            "managed_server chat_completion backend returned prompt_tokens=%s outputs=%s",
-            len(prompt_tokens),
-            len(output_tokens_list),
-        )
         # Track each completion and build choices
         n = len(output_tokens_list)

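With the ManagedServer logging gone, the only contract that matters in this block is the backend wrapper's return shape. A hedged, runnable sketch of that 4-tuple contract, using a stub in place of the real server; the payload values are invented for illustration:

import asyncio

class StubServer:
    # Stands in for self.server; mirrors only the unpacking shown in the diff.
    async def tokens_and_logprobs_completion(self, **completion_kwargs):
        prompt_tokens = [1, 2, 3]                      # token ids of the prompt
        output_tokens_list = [[4, 5], [6]]             # one list per sampled completion
        output_logprobs_list = [[-0.1, -0.2], [-0.3]]  # parallel per-token logprobs
        finish_reasons = ["stop", "length"]            # one reason per completion
        return prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons

async def main():
    server = StubServer()
    prompt_tokens, outputs, logprobs, reasons = await server.tokens_and_logprobs_completion()
    assert len(outputs) == len(logprobs) == len(reasons)  # n completions, parallel lists

asyncio.run(main())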
View file

@@ -106,13 +106,6 @@ class ServerManager:
             self.servers = [ServerHarness()]
             return
         if not isinstance(configs, list):
-            logger.warning(
-                "ServerManager: configs is NOT a list (type=%s). "
-                "Using auto-generated URLs (template mode). "
-                "Passed base_url=%s will be IGNORED.",
-                type(configs).__name__,
-                getattr(configs, "base_url", "N/A"),
-            )
             urls = []
             if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
                 nodelist = (
@@ -155,21 +148,11 @@ class ServerManager:
                 server_class(config, reasoning_config=reasoning_config)
                 for config in openai_configs
             ]
-            logger.warning(
-                "ServerManager: auto-generated %s server(s) at URLs: %s",
-                len(self.servers),
-                [c.base_url for c in openai_configs],
-            )
         elif not slurm:
             self.servers = [
                 server_class(config, reasoning_config=reasoning_config)
                 for config in configs
             ]
-            logger.warning(
-                "ServerManager: using %s explicit config(s) at URLs: %s",
-                len(self.servers),
-                [c.base_url for c in configs],
-            )
         else:
             nodelist = (
                 os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')

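For reference, the SLURM branch that survives above builds server URLs by expanding the compact nodelist. A minimal sketch of that expansion, assuming it runs inside a SLURM job; the port and URL template are illustrative assumptions, not the project's actual values:

import os

# `scontrol show hostnames` expands e.g. "gpu[01-03]" into one hostname per line.
nodelist = (
    os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
    .read()
    .split("\n")
)
urls = [f"http://{host}:9000/v1" for host in nodelist if host]  # hypothetical template
print(urls)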
View file

@@ -193,14 +193,6 @@ class VLLMServer(APIServer):
         # Prepare request for VLLM native API
         request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
         request_data.update(kwargs)
-        logger.warning(
-            "vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s",
-            self.config.base_url,
-            len(prompt_tokens),
-            request_data.get("n"),
-            request_data.get("max_tokens"),
-            request_data.get("temperature"),
-        )
         # Make async request to VLLM /generate endpoint
         async with aiohttp.ClientSession() as session:
@@ -216,11 +208,6 @@ class VLLMServer(APIServer):
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
-                logger.warning(
-                    "vllm_server completion POST done outputs=%s finish_reasons=%s",
-                    len(results.get("logprobs", [])),
-                    len(results.get("finish_reasons", [])),
-                )
         output_tokens_list = []
         output_logprobs_list = []
         finish_reasons_list = []
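For orientation, the completion path above boils down to a single POST against vLLM's native endpoint. A hedged sketch of that call, keeping the payload shape shown in the diff; the URL join and the sampling kwargs are assumptions:

import aiohttp

async def vllm_generate(base_url: str, prompt_tokens: list, **sampling_kwargs):
    # Payload shape taken from the diff: token ids in, logprobs requested.
    request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
    request_data.update(sampling_kwargs)  # e.g. n, max_tokens, temperature
    async with aiohttp.ClientSession() as session:
        # Assumed endpoint join; the class reads base_url from its config.
        async with session.post(f"{base_url}/generate", json=request_data) as response:
            response.raise_for_status()
            return await response.json()  # carries the logprobs / finish_reasons lists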
@@ -330,13 +317,6 @@ class VLLMServer(APIServer):
         request_data["temperature"] = 0.0
         request_data["top_p"] = 1.0
         request_data.setdefault("max_tokens", 1)
-        logger.warning(
-            "vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s",
-            self.config.base_url,
-            len(prompt_tokens),
-            top_k,
-            request_data.get("max_tokens"),
-        )
         async with aiohttp.ClientSession() as session:
             async with session.post(
@@ -351,10 +331,6 @@ class VLLMServer(APIServer):
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
-                logger.warning(
-                    "vllm_server get_logprobs POST done prompt_logprobs_present=%s",
-                    results.get("prompt_logprobs") is not None,
-                )
         raw_prompt_logprobs = results.get("prompt_logprobs")
         if raw_prompt_logprobs is None:
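The get_logprobs path above is effectively a prompt-scoring probe: decoding is forced greedy, at most one token is generated, and the answer is read from prompt_logprobs. A sketch of the request setup under those assumptions; requesting prompt_logprobs in the payload is inferred, since the diff only shows it being read from the response:

request_data = {"prompt": {"prompt_token_ids": [1, 2, 3]}}  # hypothetical tokens
request_data["temperature"] = 0.0         # force greedy decoding
request_data["top_p"] = 1.0
request_data.setdefault("max_tokens", 1)  # generate at most a single token
request_data["prompt_logprobs"] = 1       # assumed field, mirroring the top_k argument
# POST to /generate as in the completion path, then:
# raw_prompt_logprobs = results.get("prompt_logprobs")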
@@ -451,10 +427,6 @@ def resolve_openai_configs(
     elif isinstance(default_server_configs, list):
         server_configs = [final_openai_config]
     else:
-        logger.warning(
-            f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
-            f"Proceeding with single OpenAI server configuration based on merged settings."
-        )
         server_configs = [final_openai_config]
     return server_configs