mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
logger changes
This commit is contained in:
parent
64d3ee1bd6
commit
35587cbdc0
3 changed files with 63 additions and 52 deletions
|
|
@ -286,7 +286,7 @@ class BaseEnv(ABC):
|
|||
counter += 1
|
||||
path_changed = True
|
||||
if path_changed:
|
||||
print(
|
||||
logger.info(
|
||||
f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501
|
||||
)
|
||||
|
||||
|
|
@ -364,7 +364,7 @@ class BaseEnv(ABC):
|
|||
to_postprocess["group_overrides"] = {}
|
||||
to_postprocess["overrides"] = []
|
||||
to_postprocess["images"] = []
|
||||
print("Processing results")
|
||||
logger.debug("Processing results")
|
||||
for result in results:
|
||||
to_postprocess["tokens"].append(result[0]["tokens"])
|
||||
to_postprocess["masks"].append(result[0]["masks"])
|
||||
|
|
@ -444,7 +444,7 @@ class BaseEnv(ABC):
|
|||
setattr(self, key, data[key])
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
print(f"Saving checkpoint at step {step} with data {data}")
|
||||
logger.info("Saving checkpoint at step %s with data %s", step, data)
|
||||
if data is None:
|
||||
# Don't have anything to save, abort
|
||||
return
|
||||
|
|
@ -542,7 +542,7 @@ class BaseEnv(ABC):
|
|||
self.config.total_steps = data["num_steps"]
|
||||
if self.config.total_steps == -1:
|
||||
raise ValueError("Total steps not set in config or server!")
|
||||
print(
|
||||
logger.info(
|
||||
f"Initialized env with id {self.env_id}: "
|
||||
f"curr_step: {self.curr_step}, "
|
||||
f"checkpoint_dir: {self.checkpoint_dir}, "
|
||||
|
|
@ -779,7 +779,7 @@ class BaseEnv(ABC):
|
|||
with open(filepath, "w") as f:
|
||||
json.dump(eval_result, f, indent=2)
|
||||
|
||||
print(f"Evaluation results saved to {filepath}")
|
||||
logger.info("Evaluation results saved to %s", filepath)
|
||||
|
||||
# Write samples to JSONL file if provided
|
||||
if samples:
|
||||
|
|
@ -789,7 +789,7 @@ class BaseEnv(ABC):
|
|||
with jsonlines.open(samples_filepath, "w") as writer:
|
||||
for sample in samples:
|
||||
writer.write(sample)
|
||||
print(f"Evaluation samples saved to {samples_filepath}")
|
||||
logger.info("Evaluation samples saved to %s", samples_filepath)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
|
|
@ -823,7 +823,7 @@ class BaseEnv(ABC):
|
|||
elif resp.status >= 400:
|
||||
logging.error(f"Client error: {resp.status}, not retrying")
|
||||
return
|
||||
print(await resp.text())
|
||||
logger.debug(await resp.text())
|
||||
|
||||
def _post_json_with_compression(
|
||||
self,
|
||||
|
|
@ -927,7 +927,9 @@ class BaseEnv(ABC):
|
|||
|
||||
if self.jsonl_writer is not None:
|
||||
self.jsonl_writer.write(group)
|
||||
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
|
||||
logger.info(
|
||||
"Wrote scored group to %s", self.config.data_path_to_save_groups
|
||||
)
|
||||
|
||||
valid_groups.append(group)
|
||||
|
||||
|
|
@ -948,7 +950,7 @@ class BaseEnv(ABC):
|
|||
if isinstance(data_to_send_to_api, dict)
|
||||
else f"{len(data_to_send_to_api)} ScoredDataGroups"
|
||||
)
|
||||
print(f"Failed to send {data_type_str} after retries: {e}")
|
||||
logger.error("Failed to send %s after retries: %s", data_type_str, e)
|
||||
|
||||
async def handle_env(
|
||||
self, item_uuid: str
|
||||
|
|
@ -958,7 +960,7 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
item = self.running_items.get(item_uuid)["item"]
|
||||
if item is None:
|
||||
print(f"item {item_uuid} not found... returning")
|
||||
logger.warning("item %s not found... returning", item_uuid)
|
||||
return None
|
||||
start_time = time.time()
|
||||
logger.debug(f"handle_env: Starting with item: {item}")
|
||||
|
|
@ -979,7 +981,7 @@ class BaseEnv(ABC):
|
|||
to_postprocess = await self.postprocess_histories(to_postprocess)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scoring: {item}")
|
||||
print(e)
|
||||
logger.error("Scoring exception: %s", e)
|
||||
to_postprocess = None
|
||||
self.running_items.pop(item_uuid, None)
|
||||
duration = max(0.0, time.time() - start_time)
|
||||
|
|
@ -1120,10 +1122,9 @@ class BaseEnv(ABC):
|
|||
),
|
||||
)
|
||||
max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue)
|
||||
print(
|
||||
logger.info(
|
||||
f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, "
|
||||
f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}",
|
||||
flush=True,
|
||||
f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}"
|
||||
)
|
||||
if (self.curr_step == 0) and (len(self.workers) == 0):
|
||||
# We are starting up, so we should just skip the append to the list
|
||||
|
|
@ -1131,10 +1132,9 @@ class BaseEnv(ABC):
|
|||
else:
|
||||
self.workers_added_list.append(max_num_workers - len(self.workers))
|
||||
if len(self.workers) > max_num_workers:
|
||||
print(
|
||||
logger.info(
|
||||
f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, "
|
||||
"sending workers to backlog",
|
||||
flush=True,
|
||||
"sending workers to backlog"
|
||||
)
|
||||
num_to_reduce = len(self.workers) - max_num_workers
|
||||
running_items_to_remove = list(self.running_items.keys())[:num_to_reduce]
|
||||
|
|
@ -1280,18 +1280,22 @@ class BaseEnv(ABC):
|
|||
# Initialize the processing
|
||||
self.curr_step = 0
|
||||
|
||||
print(f"Starting to process {self.n_groups_to_process} groups...")
|
||||
logger.info("Starting to process %s groups...", self.n_groups_to_process)
|
||||
|
||||
# Process the required number of groups
|
||||
while self.curr_step < self.n_groups_to_process:
|
||||
# Get an item to process
|
||||
item = await self.get_next_item()
|
||||
if item is None:
|
||||
print("No more items to process")
|
||||
logger.info("No more items to process")
|
||||
break
|
||||
|
||||
# Process the group
|
||||
print(f"Processing group {self.curr_step + 1}/{self.n_groups_to_process}")
|
||||
logger.info(
|
||||
"Processing group %s/%s",
|
||||
self.curr_step + 1,
|
||||
self.n_groups_to_process,
|
||||
)
|
||||
|
||||
# Collect trajectories with the specified group size
|
||||
# Override the group_size temporarily
|
||||
|
|
@ -1314,13 +1318,13 @@ class BaseEnv(ABC):
|
|||
await self.wandb_log()
|
||||
|
||||
self.curr_step += 1
|
||||
print(
|
||||
logger.info(
|
||||
f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}"
|
||||
)
|
||||
else:
|
||||
print("Failed to process group, retrying...")
|
||||
logger.warning("Failed to process group, retrying...")
|
||||
|
||||
print(f"Completed processing {self.curr_step} groups")
|
||||
logger.info("Completed processing %s groups", self.curr_step)
|
||||
|
||||
# Close the output file if it's open
|
||||
if self.jsonl_writer is not None:
|
||||
|
|
@ -1354,8 +1358,7 @@ class BaseEnv(ABC):
|
|||
"""Handles exceptions with clean output for known error types."""
|
||||
if isinstance(ex, FailedExecutionException):
|
||||
# Handle argparse errors (already printed by argparse)
|
||||
print()
|
||||
print(ex.message.split("error: ")[-1])
|
||||
logger.error(ex.message.split("error: ")[-1])
|
||||
return 2
|
||||
|
||||
raise ex
|
||||
|
|
@ -1416,7 +1419,7 @@ class BaseEnv(ABC):
|
|||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
logger.info("Loaded config from %s", self.config)
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
|
|
@ -1440,11 +1443,11 @@ class BaseEnv(ABC):
|
|||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
|
||||
# Debug logging for CLI args
|
||||
print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}")
|
||||
print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}")
|
||||
print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}")
|
||||
print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}")
|
||||
|
||||
logger.debug("[CLI DEBUG] cli_passed_flags = %s", cli_passed_flags)
|
||||
logger.debug("[CLI DEBUG] openai_full_prefix = %s", openai_full_prefix)
|
||||
logger.debug("[CLI DEBUG] oai_cli_passed_args = %s", oai_cli_passed_args)
|
||||
logger.debug("[CLI DEBUG] yaml_oai_config = %s", yaml_oai_config)
|
||||
|
||||
# Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided
|
||||
# This allows any environment to use --openai.* CLI args without modifying config_init
|
||||
# Use a new variable to avoid UnboundLocalError from closure scoping
|
||||
|
|
@ -1472,19 +1475,21 @@ class BaseEnv(ABC):
|
|||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
print(
|
||||
f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}"
|
||||
logger.debug(
|
||||
"[CLI DEBUG] default_openai_config_.model_dump() = %s",
|
||||
default_openai_config_.model_dump(),
|
||||
)
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
print(
|
||||
f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}"
|
||||
logger.debug(
|
||||
"[CLI DEBUG] openai_config_dict after merge = %s",
|
||||
openai_config_dict,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
logger.debug(
|
||||
"[CLI DEBUG] Not merging: default_openai_config_ "
|
||||
f"type={type(default_openai_config_)}, "
|
||||
f"yaml_oai_config type={type(yaml_oai_config)}"
|
||||
|
|
@ -1637,7 +1642,7 @@ class BaseEnv(ABC):
|
|||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
logger.info("Loaded config from %s", self.config)
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
|
|
@ -1810,7 +1815,7 @@ class BaseEnv(ABC):
|
|||
"data_path_to_save_groups must be set for process mode"
|
||||
)
|
||||
|
||||
print(
|
||||
logger.info(
|
||||
f"Processing {env_config.total_steps} groups of "
|
||||
f"{env_config.group_size} responses and "
|
||||
f"writing to {env_config.data_path_to_save_groups}"
|
||||
|
|
@ -1906,7 +1911,7 @@ class BaseEnv(ABC):
|
|||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
logger.info("Loaded config from %s", self.config)
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
|
|
@ -2092,7 +2097,7 @@ class BaseEnv(ABC):
|
|||
yaml.dump(
|
||||
config_dict, f, default_flow_style=False, sort_keys=False
|
||||
)
|
||||
print(f"Dumped evaluate config to {config_filepath}")
|
||||
logger.info("Dumped evaluate config to %s", config_filepath)
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance
|
||||
|
|
@ -2103,7 +2108,7 @@ class BaseEnv(ABC):
|
|||
testing=server_manager_config.testing,
|
||||
)
|
||||
|
||||
print("Running evaluation...")
|
||||
logger.info("Running evaluation...")
|
||||
# Handle the case where we might already be in an event loop
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
@ -2112,6 +2117,6 @@ class BaseEnv(ABC):
|
|||
except RuntimeError:
|
||||
asyncio.run(env._run_evaluate())
|
||||
|
||||
print("Evaluation completed.")
|
||||
logger.info("Evaluation completed.")
|
||||
|
||||
return CliEvaluateConfig
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue