logger changes

This commit is contained in:
Jai Suphavadeeprasit 2026-02-27 16:13:51 -05:00
parent 64d3ee1bd6
commit 35587cbdc0
3 changed files with 63 additions and 52 deletions

View file

@ -1,4 +1,5 @@
import gzip import gzip
import logging
import time import time
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -25,6 +26,7 @@ MIN_ENV_WEIGHT = (
# Message import removed - using Dict[str, Any] for more flexible validation # Message import removed - using Dict[str, Any] for more flexible validation
app = FastAPI(title="AtroposLib API") app = FastAPI(title="AtroposLib API")
logger = logging.getLogger(__name__)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -391,7 +393,10 @@ async def get_batch():
app.state.curr_batch.append(batch) app.state.curr_batch.append(batch)
curr_batch = app.state.curr_batch.pop() curr_batch = app.state.curr_batch.pop()
# check length before sending # check length before sending
print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences") logger.info(
"Sending batch of %s sequences",
sum(len(x["tokens"]) for x in curr_batch),
)
return {"batch": curr_batch} return {"batch": curr_batch}

View file

@ -286,7 +286,7 @@ class BaseEnv(ABC):
counter += 1 counter += 1
path_changed = True path_changed = True
if path_changed: if path_changed:
print( logger.info(
f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501 f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501
) )
@ -364,7 +364,7 @@ class BaseEnv(ABC):
to_postprocess["group_overrides"] = {} to_postprocess["group_overrides"] = {}
to_postprocess["overrides"] = [] to_postprocess["overrides"] = []
to_postprocess["images"] = [] to_postprocess["images"] = []
print("Processing results") logger.debug("Processing results")
for result in results: for result in results:
to_postprocess["tokens"].append(result[0]["tokens"]) to_postprocess["tokens"].append(result[0]["tokens"])
to_postprocess["masks"].append(result[0]["masks"]) to_postprocess["masks"].append(result[0]["masks"])
@ -444,7 +444,7 @@ class BaseEnv(ABC):
setattr(self, key, data[key]) setattr(self, key, data[key])
def save_checkpoint(self, step, data=None): def save_checkpoint(self, step, data=None):
print(f"Saving checkpoint at step {step} with data {data}") logger.info("Saving checkpoint at step %s with data %s", step, data)
if data is None: if data is None:
# Don't have anything to save, abort # Don't have anything to save, abort
return return
@ -542,7 +542,7 @@ class BaseEnv(ABC):
self.config.total_steps = data["num_steps"] self.config.total_steps = data["num_steps"]
if self.config.total_steps == -1: if self.config.total_steps == -1:
raise ValueError("Total steps not set in config or server!") raise ValueError("Total steps not set in config or server!")
print( logger.info(
f"Initialized env with id {self.env_id}: " f"Initialized env with id {self.env_id}: "
f"curr_step: {self.curr_step}, " f"curr_step: {self.curr_step}, "
f"checkpoint_dir: {self.checkpoint_dir}, " f"checkpoint_dir: {self.checkpoint_dir}, "
@ -779,7 +779,7 @@ class BaseEnv(ABC):
with open(filepath, "w") as f: with open(filepath, "w") as f:
json.dump(eval_result, f, indent=2) json.dump(eval_result, f, indent=2)
print(f"Evaluation results saved to {filepath}") logger.info("Evaluation results saved to %s", filepath)
# Write samples to JSONL file if provided # Write samples to JSONL file if provided
if samples: if samples:
@ -789,7 +789,7 @@ class BaseEnv(ABC):
with jsonlines.open(samples_filepath, "w") as writer: with jsonlines.open(samples_filepath, "w") as writer:
for sample in samples: for sample in samples:
writer.write(sample) writer.write(sample)
print(f"Evaluation samples saved to {samples_filepath}") logger.info("Evaluation samples saved to %s", samples_filepath)
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
@ -823,7 +823,7 @@ class BaseEnv(ABC):
elif resp.status >= 400: elif resp.status >= 400:
logging.error(f"Client error: {resp.status}, not retrying") logging.error(f"Client error: {resp.status}, not retrying")
return return
print(await resp.text()) logger.debug(await resp.text())
def _post_json_with_compression( def _post_json_with_compression(
self, self,
@ -927,7 +927,9 @@ class BaseEnv(ABC):
if self.jsonl_writer is not None: if self.jsonl_writer is not None:
self.jsonl_writer.write(group) self.jsonl_writer.write(group)
print(f"Wrote scored group to {self.config.data_path_to_save_groups}") logger.info(
"Wrote scored group to %s", self.config.data_path_to_save_groups
)
valid_groups.append(group) valid_groups.append(group)
@ -948,7 +950,7 @@ class BaseEnv(ABC):
if isinstance(data_to_send_to_api, dict) if isinstance(data_to_send_to_api, dict)
else f"{len(data_to_send_to_api)} ScoredDataGroups" else f"{len(data_to_send_to_api)} ScoredDataGroups"
) )
print(f"Failed to send {data_type_str} after retries: {e}") logger.error("Failed to send %s after retries: %s", data_type_str, e)
async def handle_env( async def handle_env(
self, item_uuid: str self, item_uuid: str
@ -958,7 +960,7 @@ class BaseEnv(ABC):
""" """
item = self.running_items.get(item_uuid)["item"] item = self.running_items.get(item_uuid)["item"]
if item is None: if item is None:
print(f"item {item_uuid} not found... returning") logger.warning("item %s not found... returning", item_uuid)
return None return None
start_time = time.time() start_time = time.time()
logger.debug(f"handle_env: Starting with item: {item}") logger.debug(f"handle_env: Starting with item: {item}")
@ -979,7 +981,7 @@ class BaseEnv(ABC):
to_postprocess = await self.postprocess_histories(to_postprocess) to_postprocess = await self.postprocess_histories(to_postprocess)
except Exception as e: except Exception as e:
logger.error(f"Error in scoring: {item}") logger.error(f"Error in scoring: {item}")
print(e) logger.error("Scoring exception: %s", e)
to_postprocess = None to_postprocess = None
self.running_items.pop(item_uuid, None) self.running_items.pop(item_uuid, None)
duration = max(0.0, time.time() - start_time) duration = max(0.0, time.time() - start_time)
@ -1120,10 +1122,9 @@ class BaseEnv(ABC):
), ),
) )
max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue) max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue)
print( logger.info(
f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, " f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, "
f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}", f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}"
flush=True,
) )
if (self.curr_step == 0) and (len(self.workers) == 0): if (self.curr_step == 0) and (len(self.workers) == 0):
# We are starting up, so we should just skip the append to the list # We are starting up, so we should just skip the append to the list
@ -1131,10 +1132,9 @@ class BaseEnv(ABC):
else: else:
self.workers_added_list.append(max_num_workers - len(self.workers)) self.workers_added_list.append(max_num_workers - len(self.workers))
if len(self.workers) > max_num_workers: if len(self.workers) > max_num_workers:
print( logger.info(
f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, " f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, "
"sending workers to backlog", "sending workers to backlog"
flush=True,
) )
num_to_reduce = len(self.workers) - max_num_workers num_to_reduce = len(self.workers) - max_num_workers
running_items_to_remove = list(self.running_items.keys())[:num_to_reduce] running_items_to_remove = list(self.running_items.keys())[:num_to_reduce]
@ -1280,18 +1280,22 @@ class BaseEnv(ABC):
# Initialize the processing # Initialize the processing
self.curr_step = 0 self.curr_step = 0
print(f"Starting to process {self.n_groups_to_process} groups...") logger.info("Starting to process %s groups...", self.n_groups_to_process)
# Process the required number of groups # Process the required number of groups
while self.curr_step < self.n_groups_to_process: while self.curr_step < self.n_groups_to_process:
# Get an item to process # Get an item to process
item = await self.get_next_item() item = await self.get_next_item()
if item is None: if item is None:
print("No more items to process") logger.info("No more items to process")
break break
# Process the group # Process the group
print(f"Processing group {self.curr_step + 1}/{self.n_groups_to_process}") logger.info(
"Processing group %s/%s",
self.curr_step + 1,
self.n_groups_to_process,
)
# Collect trajectories with the specified group size # Collect trajectories with the specified group size
# Override the group_size temporarily # Override the group_size temporarily
@ -1314,13 +1318,13 @@ class BaseEnv(ABC):
await self.wandb_log() await self.wandb_log()
self.curr_step += 1 self.curr_step += 1
print( logger.info(
f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}" f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}"
) )
else: else:
print("Failed to process group, retrying...") logger.warning("Failed to process group, retrying...")
print(f"Completed processing {self.curr_step} groups") logger.info("Completed processing %s groups", self.curr_step)
# Close the output file if it's open # Close the output file if it's open
if self.jsonl_writer is not None: if self.jsonl_writer is not None:
@ -1354,8 +1358,7 @@ class BaseEnv(ABC):
"""Handles exceptions with clean output for known error types.""" """Handles exceptions with clean output for known error types."""
if isinstance(ex, FailedExecutionException): if isinstance(ex, FailedExecutionException):
# Handle argparse errors (already printed by argparse) # Handle argparse errors (already printed by argparse)
print() logger.error(ex.message.split("error: ")[-1])
print(ex.message.split("error: ")[-1])
return 2 return 2
raise ex raise ex
@ -1416,7 +1419,7 @@ class BaseEnv(ABC):
if self.config is not None: if self.config is not None:
with open(self.config, "r") as f: with open(self.config, "r") as f:
yaml_config = yaml.safe_load(f) yaml_config = yaml.safe_load(f)
print(f"Loaded config from {self.config}") logger.info("Loaded config from %s", self.config)
else: else:
yaml_config = {} yaml_config = {}
@ -1440,11 +1443,11 @@ class BaseEnv(ABC):
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
# Debug logging for CLI args # Debug logging for CLI args
print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}") logger.debug("[CLI DEBUG] cli_passed_flags = %s", cli_passed_flags)
print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}") logger.debug("[CLI DEBUG] openai_full_prefix = %s", openai_full_prefix)
print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}") logger.debug("[CLI DEBUG] oai_cli_passed_args = %s", oai_cli_passed_args)
print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}") logger.debug("[CLI DEBUG] yaml_oai_config = %s", yaml_oai_config)
# Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided
# This allows any environment to use --openai.* CLI args without modifying config_init # This allows any environment to use --openai.* CLI args without modifying config_init
# Use a new variable to avoid UnboundLocalError from closure scoping # Use a new variable to avoid UnboundLocalError from closure scoping
@ -1472,19 +1475,21 @@ class BaseEnv(ABC):
if isinstance(default_openai_config_, APIServerConfig) and isinstance( if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict yaml_oai_config, dict
): ):
print( logger.debug(
f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}" "[CLI DEBUG] default_openai_config_.model_dump() = %s",
default_openai_config_.model_dump(),
) )
openai_config_dict = merge_dicts( openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
yaml_oai_config, yaml_oai_config,
oai_cli_passed_args, oai_cli_passed_args,
) )
print( logger.debug(
f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}" "[CLI DEBUG] openai_config_dict after merge = %s",
openai_config_dict,
) )
else: else:
print( logger.debug(
"[CLI DEBUG] Not merging: default_openai_config_ " "[CLI DEBUG] Not merging: default_openai_config_ "
f"type={type(default_openai_config_)}, " f"type={type(default_openai_config_)}, "
f"yaml_oai_config type={type(yaml_oai_config)}" f"yaml_oai_config type={type(yaml_oai_config)}"
@ -1637,7 +1642,7 @@ class BaseEnv(ABC):
if self.config is not None: if self.config is not None:
with open(self.config, "r") as f: with open(self.config, "r") as f:
yaml_config = yaml.safe_load(f) yaml_config = yaml.safe_load(f)
print(f"Loaded config from {self.config}") logger.info("Loaded config from %s", self.config)
else: else:
yaml_config = {} yaml_config = {}
@ -1810,7 +1815,7 @@ class BaseEnv(ABC):
"data_path_to_save_groups must be set for process mode" "data_path_to_save_groups must be set for process mode"
) )
print( logger.info(
f"Processing {env_config.total_steps} groups of " f"Processing {env_config.total_steps} groups of "
f"{env_config.group_size} responses and " f"{env_config.group_size} responses and "
f"writing to {env_config.data_path_to_save_groups}" f"writing to {env_config.data_path_to_save_groups}"
@ -1906,7 +1911,7 @@ class BaseEnv(ABC):
if self.config is not None: if self.config is not None:
with open(self.config, "r") as f: with open(self.config, "r") as f:
yaml_config = yaml.safe_load(f) yaml_config = yaml.safe_load(f)
print(f"Loaded config from {self.config}") logger.info("Loaded config from %s", self.config)
else: else:
yaml_config = {} yaml_config = {}
@ -2092,7 +2097,7 @@ class BaseEnv(ABC):
yaml.dump( yaml.dump(
config_dict, f, default_flow_style=False, sort_keys=False config_dict, f, default_flow_style=False, sort_keys=False
) )
print(f"Dumped evaluate config to {config_filepath}") logger.info("Dumped evaluate config to %s", config_filepath)
# --- Create and Run Environment --- # --- Create and Run Environment ---
# Create the environment instance # Create the environment instance
@ -2103,7 +2108,7 @@ class BaseEnv(ABC):
testing=server_manager_config.testing, testing=server_manager_config.testing,
) )
print("Running evaluation...") logger.info("Running evaluation...")
# Handle the case where we might already be in an event loop # Handle the case where we might already be in an event loop
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@ -2112,6 +2117,6 @@ class BaseEnv(ABC):
except RuntimeError: except RuntimeError:
asyncio.run(env._run_evaluate()) asyncio.run(env._run_evaluate())
print("Evaluation completed.") logger.info("Evaluation completed.")
return CliEvaluateConfig return CliEvaluateConfig

View file

@ -8,6 +8,7 @@ This wrapper maintains a tree structure of sequences, where:
""" """
import os import os
import logging
import time import time
import uuid import uuid
import warnings import warnings
@ -23,6 +24,8 @@ from pydantic import BaseModel
from atroposlib.envs.server_handling.server_baseline import APIServer from atroposlib.envs.server_handling.server_baseline import APIServer
logger = logging.getLogger(__name__)
class SequenceNode(BaseModel): class SequenceNode(BaseModel):
""" """
@ -292,16 +295,14 @@ class ManagedServer:
if self._debug_requests_enabled(): if self._debug_requests_enabled():
msg_count = len(messages) msg_count = len(messages)
prompt_preview = prompt.replace("\n", "\\n")[:600] prompt_preview = prompt.replace("\n", "\\n")[:600]
print( logger.debug(
f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} " "[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s",
f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} " msg_count,
f"temperature={completion_kwargs.get('temperature')}", completion_kwargs.get("n"),
flush=True, completion_kwargs.get("max_tokens"),
) completion_kwargs.get("temperature"),
print(
f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
flush=True,
) )
logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview)
# Set model name if not provided # Set model name if not provided
if "model" not in completion_kwargs: if "model" not in completion_kwargs: