clean log

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 12:09:08 -04:00
parent d1b0dee8f7
commit 600c54f5f8
7 changed files with 15 additions and 206 deletions

View file

@ -907,7 +907,7 @@ class BaseEnv(ABC):
"ensure your trainer handles this appropriately."
)
elif abort_on_any_max_length_exceeded and any(
[len(x) > self.max_token_len for x in group["tokens"]]
[len(x) >= self.max_token_len for x in group["tokens"]]
):
logger.warning("Token length is too long in a group, skipping...")
continue

View file

@ -447,33 +447,14 @@ class ManagedServer:
if not self.track_tree and self.tokenizer is not None:
input_ids = self._compute_input_ids(prompt, extending_node)
completion_kwargs["input_ids"] = input_ids
logger.warning(
"managed_server chat_completion prepared input_ids=%s extending=%s",
len(input_ids),
extending_node is not None,
)
else:
logger.warning(
"managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s",
self.track_tree,
self.tokenizer is not None,
)
# Call the tokens and logprobs wrapper directly
logger.warning(
"managed_server chat_completion calling backend completion wrapper"
)
(
prompt_tokens,
output_tokens_list,
output_logprobs_list,
finish_reasons,
) = await self.server.tokens_and_logprobs_completion(**completion_kwargs)
logger.warning(
"managed_server chat_completion backend returned prompt_tokens=%s outputs=%s",
len(prompt_tokens),
len(output_tokens_list),
)
# Track each completion and build choices
n = len(output_tokens_list)

View file

@ -106,13 +106,6 @@ class ServerManager:
self.servers = [ServerHarness()]
return
if not isinstance(configs, list):
logger.warning(
"ServerManager: configs is NOT a list (type=%s). "
"Using auto-generated URLs (template mode). "
"Passed base_url=%s will be IGNORED.",
type(configs).__name__,
getattr(configs, "base_url", "N/A"),
)
urls = []
if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
nodelist = (
@ -155,21 +148,11 @@ class ServerManager:
server_class(config, reasoning_config=reasoning_config)
for config in openai_configs
]
logger.warning(
"ServerManager: auto-generated %s server(s) at URLs: %s",
len(self.servers),
[c.base_url for c in openai_configs],
)
elif not slurm:
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in configs
]
logger.warning(
"ServerManager: using %s explicit config(s) at URLs: %s",
len(self.servers),
[c.base_url for c in configs],
)
else:
nodelist = (
os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')

View file

@ -193,14 +193,6 @@ class VLLMServer(APIServer):
# Prepare request for VLLM native API
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
request_data.update(kwargs)
logger.warning(
"vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s",
self.config.base_url,
len(prompt_tokens),
request_data.get("n"),
request_data.get("max_tokens"),
request_data.get("temperature"),
)
# Make async request to VLLM /generate endpoint
async with aiohttp.ClientSession() as session:
@ -216,11 +208,6 @@ class VLLMServer(APIServer):
) as response:
response.raise_for_status()
results = await response.json()
logger.warning(
"vllm_server completion POST done outputs=%s finish_reasons=%s",
len(results.get("logprobs", [])),
len(results.get("finish_reasons", [])),
)
output_tokens_list = []
output_logprobs_list = []
finish_reasons_list = []
@ -330,13 +317,6 @@ class VLLMServer(APIServer):
request_data["temperature"] = 0.0
request_data["top_p"] = 1.0
request_data.setdefault("max_tokens", 1)
logger.warning(
"vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s",
self.config.base_url,
len(prompt_tokens),
top_k,
request_data.get("max_tokens"),
)
async with aiohttp.ClientSession() as session:
async with session.post(
@ -351,10 +331,6 @@ class VLLMServer(APIServer):
) as response:
response.raise_for_status()
results = await response.json()
logger.warning(
"vllm_server get_logprobs POST done prompt_logprobs_present=%s",
results.get("prompt_logprobs") is not None,
)
raw_prompt_logprobs = results.get("prompt_logprobs")
if raw_prompt_logprobs is None:
@ -451,10 +427,6 @@ def resolve_openai_configs(
elif isinstance(default_server_configs, list):
server_configs = [final_openai_config]
else:
logger.warning(
f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
f"Proceeding with single OpenAI server configuration based on merged settings."
)
server_configs = [final_openai_config]
return server_configs