Mirror of https://github.com/NousResearch/atropos.git
commit 600c54f5f8 (parent d1b0dee8f7)

    clean log

7 changed files with 15 additions and 206 deletions
@@ -4,7 +4,7 @@ import time
 import uuid
 from typing import Any, Dict, List, Optional
 
-from fastapi import FastAPI, Request, status
+from fastapi import FastAPI, status
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.responses import PlainTextResponse

@@ -351,7 +351,7 @@ async def info():
 
 
 @app.get("/batch")
-async def get_batch(request: Request):
+async def get_batch():
     # Check if trainer has registered first
     if not hasattr(app.state, "started"):
         return {

@@ -363,27 +363,8 @@ async def get_batch(request: Request):
     if not app.state.started:
         app.state.started = True
 
-    client = request.client
-    client_addr = (
-        f"{client.host}:{client.port}" if client is not None else "unknown-client"
-    )
-    client_tag = request.headers.get("x-atropos-client", "unknown")
-    client_pid = request.headers.get("x-atropos-pid", "unknown")
-
     if len(app.state.curr_batch) > 0:
-        curr_batch = app.state.curr_batch.pop()
-        logger.warning(
-            "API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
-            "groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
-            client_tag,
-            client_pid,
-            client_addr,
-            len(curr_batch),
-            sum(len(x["tokens"]) for x in curr_batch),
-            len(app.state.curr_batch),
-            len(app.state.queue),
-        )
-        return {"batch": curr_batch}
+        return {"batch": app.state.curr_batch.pop()}
     else:
         new_batches = []
         # Check if any envs have minimum allocations

@@ -413,21 +394,6 @@ async def get_batch(request: Request):
         )
         steps_to_take = len(new_batches)
         if steps_to_take == 0:
-            now = time.time()
-            last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
-            if now - last_empty_log > 30:
-                logger.warning(
-                    "API /batch no full batch ready for client=%s pid=%s addr=%s: "
-                    "queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
-                    client_tag,
-                    client_pid,
-                    client_addr,
-                    len(app.state.queue),
-                    sum(len(x.get("tokens", [])) for x in app.state.queue),
-                    len(app.state.curr_batch),
-                    getattr(app.state, "batchsize", -1),
-                )
-                app.state._last_empty_batch_log = now
             return {"batch": None}
         app.state.status_dict["step"] += steps_to_take
         # chunk it

@@ -435,18 +401,9 @@ async def get_batch(request: Request):
             app.state.curr_batch.append(batch)
         curr_batch = app.state.curr_batch.pop()
         # check length before sending
-        logger.warning(
-            "API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
-            "with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
-            steps_to_take,
-            client_tag,
-            client_pid,
-            client_addr,
-            len(curr_batch),
+        logger.info(
+            "Sending batch of %s sequences",
             sum(len(x["tokens"]) for x in curr_batch),
-            len(app.state.curr_batch),
-            len(app.state.queue),
-            app.state.status_dict["step"],
         )
         return {"batch": curr_batch}
 

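Note on the /batch contract the hunks above clean up: the endpoint returns {"batch": None} until enough groups have accumulated for a full trainer batch, and {"batch": [...]} (a list of groups, each carrying a "tokens" field) once one is ready. A minimal polling sketch from the trainer side, assuming a local URL; the host, port, and use of the requests library are illustrative, not part of this commit:

import time

import requests  # any HTTP client works; requests is just for illustration


def poll_batch(api_url: str = "http://localhost:8000") -> list:
    """Poll GET /batch until the API has assembled a full trainer batch."""
    while True:
        data = requests.get(f"{api_url}/batch").json()
        if data["batch"] is not None:
            return data["batch"]  # list of groups, each carrying "tokens"
        time.sleep(1)  # no full batch yet; back off briefly and retry
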
@@ -907,7 +907,7 @@ class BaseEnv(ABC):
                     "ensure your trainer handles this appropriately."
                 )
             elif abort_on_any_max_length_exceeded and any(
-                [len(x) > self.max_token_len for x in group["tokens"]]
+                [len(x) >= self.max_token_len for x in group["tokens"]]
             ):
                 logger.warning("Token length is too long in a group, skipping...")
                 continue

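Aside from the logging removal, the hunk above changes > to >=, so a sequence exactly at max_token_len now also triggers the skip. A small self-contained sketch of that gate (the helper name is illustrative):

def should_skip_group(group: dict, max_token_len: int) -> bool:
    # Mirrors the elif above: any sequence at or beyond max_token_len
    # causes the whole group to be dropped with a warning.
    return any(len(tokens) >= max_token_len for tokens in group["tokens"])


# With max_token_len=4, a group containing a 4-token sequence is now skipped.
assert should_skip_group({"tokens": [[1, 2, 3], [1, 2, 3, 4]]}, max_token_len=4)
assert not should_skip_group({"tokens": [[1, 2, 3]]}, max_token_len=4)
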
@@ -447,33 +447,14 @@ class ManagedServer:
         if not self.track_tree and self.tokenizer is not None:
             input_ids = self._compute_input_ids(prompt, extending_node)
             completion_kwargs["input_ids"] = input_ids
-            logger.warning(
-                "managed_server chat_completion prepared input_ids=%s extending=%s",
-                len(input_ids),
-                extending_node is not None,
-            )
-        else:
-            logger.warning(
-                "managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s",
-                self.track_tree,
-                self.tokenizer is not None,
-            )
 
         # Call the tokens and logprobs wrapper directly
-        logger.warning(
-            "managed_server chat_completion calling backend completion wrapper"
-        )
         (
             prompt_tokens,
             output_tokens_list,
             output_logprobs_list,
             finish_reasons,
         ) = await self.server.tokens_and_logprobs_completion(**completion_kwargs)
-        logger.warning(
-            "managed_server chat_completion backend returned prompt_tokens=%s outputs=%s",
-            len(prompt_tokens),
-            len(output_tokens_list),
-        )
 
         # Track each completion and build choices
         n = len(output_tokens_list)

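For readers following the remaining flow: tokens_and_logprobs_completion returns a 4-tuple of the prompt tokens plus three parallel per-choice lists. A rough, self-contained sketch of fanning that tuple out into per-completion records; the record keys and the element types in the annotation are assumptions for illustration, not the class's actual schema:

from typing import Any, Dict, List, Tuple


def fan_out_choices(
    result: Tuple[List[int], List[List[int]], List[List[float]], List[str]]
) -> List[Dict[str, Any]]:
    """Split (prompt_tokens, output_tokens_list, output_logprobs_list,
    finish_reasons) into one record per sampled completion."""
    prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons = result
    return [
        {"prompt_tokens": prompt_tokens, "tokens": t, "logprobs": lp, "finish_reason": fr}
        for t, lp, fr in zip(output_tokens_list, output_logprobs_list, finish_reasons)
    ]
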
@@ -106,13 +106,6 @@ class ServerManager:
             self.servers = [ServerHarness()]
             return
         if not isinstance(configs, list):
-            logger.warning(
-                "ServerManager: configs is NOT a list (type=%s). "
-                "Using auto-generated URLs (template mode). "
-                "Passed base_url=%s will be IGNORED.",
-                type(configs).__name__,
-                getattr(configs, "base_url", "N/A"),
-            )
             urls = []
             if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
                 nodelist = (

@@ -155,21 +148,11 @@ class ServerManager:
                 server_class(config, reasoning_config=reasoning_config)
                 for config in openai_configs
             ]
-            logger.warning(
-                "ServerManager: auto-generated %s server(s) at URLs: %s",
-                len(self.servers),
-                [c.base_url for c in openai_configs],
-            )
         elif not slurm:
             self.servers = [
                 server_class(config, reasoning_config=reasoning_config)
                 for config in configs
             ]
-            logger.warning(
-                "ServerManager: using %s explicit config(s) at URLs: %s",
-                len(self.servers),
-                [c.base_url for c in configs],
-            )
         else:
             nodelist = (
                 os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')

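For context on the SLURM branch kept above, the nodelist expansion pattern looks roughly like this; the port and URL template are assumptions for illustration, not values from this diff:

import os


def slurm_node_urls(port: int = 9000) -> list:
    """Expand SLURM_JOB_NODELIST into one inference-server URL per allocated node."""
    hosts = (
        os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
        .read()
        .splitlines()
    )
    return [f"http://{host}:{port}/v1" for host in hosts]
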
@@ -193,14 +193,6 @@ class VLLMServer(APIServer):
         # Prepare request for VLLM native API
         request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
         request_data.update(kwargs)
-        logger.warning(
-            "vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s",
-            self.config.base_url,
-            len(prompt_tokens),
-            request_data.get("n"),
-            request_data.get("max_tokens"),
-            request_data.get("temperature"),
-        )
 
         # Make async request to VLLM /generate endpoint
         async with aiohttp.ClientSession() as session:

@@ -216,11 +208,6 @@ class VLLMServer(APIServer):
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
-                logger.warning(
-                    "vllm_server completion POST done outputs=%s finish_reasons=%s",
-                    len(results.get("logprobs", [])),
-                    len(results.get("finish_reasons", [])),
-                )
                 output_tokens_list = []
                 output_logprobs_list = []
                 finish_reasons_list = []

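A condensed, self-contained sketch of the request/response shape used by the completion path above; the json= encoding and the exact join of base_url with /generate are assumptions, while the payload and the "logprobs"/"finish_reasons" response fields follow the hunks:

import aiohttp


async def vllm_generate(base_url: str, prompt_tokens: list, **kwargs) -> dict:
    """POST a pre-tokenized prompt to the server's /generate endpoint."""
    request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
    request_data.update(kwargs)  # e.g. n, max_tokens, temperature
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/generate", json=request_data) as response:
            response.raise_for_status()
            return await response.json()  # includes "logprobs" and "finish_reasons"
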
@@ -330,13 +317,6 @@ class VLLMServer(APIServer):
         request_data["temperature"] = 0.0
         request_data["top_p"] = 1.0
         request_data.setdefault("max_tokens", 1)
-        logger.warning(
-            "vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s",
-            self.config.base_url,
-            len(prompt_tokens),
-            top_k,
-            request_data.get("max_tokens"),
-        )
 
         async with aiohttp.ClientSession() as session:
             async with session.post(

@@ -351,10 +331,6 @@ class VLLMServer(APIServer):
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
-                logger.warning(
-                    "vllm_server get_logprobs POST done prompt_logprobs_present=%s",
-                    results.get("prompt_logprobs") is not None,
-                )
 
                 raw_prompt_logprobs = results.get("prompt_logprobs")
                 if raw_prompt_logprobs is None:

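The get_logprobs path above scores a fixed token sequence by forcing greedy settings and a single generated token, then reading the prompt_logprobs field off the response. A self-contained sketch under the same assumptions as the previous one (endpoint path and json= encoding assumed; how top_k is threaded into the request is omitted here):

import aiohttp


async def score_prompt(base_url: str, prompt_tokens: list):
    """Return per-token prompt_logprobs for a fixed token sequence."""
    request_data = {
        "prompt": {"prompt_token_ids": prompt_tokens},
        "temperature": 0.0,  # greedy; we only want the prompt scored
        "top_p": 1.0,
        "max_tokens": 1,  # minimal generation just to trigger scoring
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/generate", json=request_data) as response:
            response.raise_for_status()
            results = await response.json()
    return results.get("prompt_logprobs")
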
@@ -451,10 +427,6 @@ def resolve_openai_configs(
     elif isinstance(default_server_configs, list):
         server_configs = [final_openai_config]
     else:
-        logger.warning(
-            f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
-            f"Proceeding with single OpenAI server configuration based on merged settings."
-        )
         server_configs = [final_openai_config]
 
     return server_configs