#!/usr/bin/env python3
import argparse
import asyncio
import logging
import os

from dotenv import load_dotenv

from atroposlib.envs.base import OpenaiConfig
from atroposlib.envs.reward_fns import registry
from atroposlib.utils.config_handler import ConfigHandler
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
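
# load_dotenv() reads a .env file from the working directory, if one exists. A
# minimal sketch of such a file for this script (the key is the variable read via
# os.environ below; the value is a placeholder, not a real credential):
#
#   OPENAI_API_KEY=sk-...
#
# Config files may also reference other variables as ${VAR_NAME} in api_key
# fields; those can be defined in the same .env file.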


def parse_arguments():
    parser = argparse.ArgumentParser(description="Dataset environment local server")
    parser.add_argument(
        "--config",
        type=str,
        default="dataset_local",
        help=(
            "Configuration file name (without .yaml extension) relative to "
            "environments/dataset_environment/configs/, or full path to a YAML file."
        ),
    )
    return parser.parse_args()
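
# Example invocations, assuming this file is saved as local_server.py inside
# environments/dataset_environment/ (the filename is an assumption, not part of
# the code above):
#
#   python local_server.py --config dataset_local
#       -> resolves to environments/dataset_environment/configs/dataset_local.yaml
#   python local_server.py --config /path/to/custom_config.yaml
#       -> any absolute path, path containing "/", or name ending in .yaml is
#          used as given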


async def main():
    logger.info("Starting Dataset environment local server")

    # Parse command line arguments
    args = parse_arguments()

    # Initialize config handler
    config_handler = ConfigHandler()

    # Determine config path
    if (
        os.path.isabs(args.config)
        or "/" in args.config
        or args.config.endswith(".yaml")
    ):
        config_path = args.config
    else:
        # Assume it's a name relative to the new default directory
        config_path = os.path.join(
            os.path.dirname(__file__), "configs", f"{args.config}.yaml"
        )

    logger.info(f"Loading configuration from: {config_path}")

    try:
        with open(config_path, "r") as f:
            import yaml

            raw_config = yaml.safe_load(f)
        logger.info("Loaded configuration successfully")
    except FileNotFoundError:
        logger.error(f"Configuration file not found at: {config_path}")
        logger.info("Ensure the --config argument is correct or the file exists.")
        return
    except Exception as e:
        logger.error(f"Error loading config from {config_path}: {e}")
        return

    # Ensure dataset configuration exists (assuming it's top-level in these files)
    if "dataset" not in raw_config:
        logger.warning(
            "'dataset' key not found at the top level of the config file. "
            "Assuming the entire file is the dataset configuration."
        )
        # Treat the whole raw_config as the 'dataset' section for compatibility
        dataset_section = raw_config
    else:
        dataset_section = raw_config["dataset"]

    if "dataset_name" not in dataset_section:
        logger.error("dataset_name not found in dataset configuration")
        return
    if "prompt_field" not in dataset_section:
        logger.error("prompt_field not found in dataset configuration")
        return

    # Configure the dataset environment
    # Merging logic: Start with raw_config defaults, then dataset_section specifics
    env_config_data = {**raw_config, **dataset_section}

    # Pydantic will ignore extra fields, so we just pass everything
    try:
        env_config = DatasetEnvConfig(**env_config_data)
    except Exception as pydantic_error:
        logger.error(f"Error validating configuration: {pydantic_error}")
        return

    # Preload reward functions
    reward_names_to_load = set()
    if env_config.reward_funcs:
        reward_names_to_load.update(env_config.reward_funcs)
    if env_config.reward_functions:
        for rf_config in env_config.reward_functions:
            if isinstance(rf_config, str):
                reward_names_to_load.add(rf_config)
            elif isinstance(rf_config, dict) and "type" in rf_config:
                reward_names_to_load.add(rf_config["type"])

    if reward_names_to_load:
        logger.info(f"Preloading reward functions: {list(reward_names_to_load)}")
        for func_name in reward_names_to_load:
            try:
                registry.get(func_name)
                logger.info(f"Successfully loaded reward function: {func_name}")
            except Exception as e:
                logger.error(f"Failed to load reward function {func_name}: {e}")

    # Server configuration - process env vars
    server_configs = []

    if "server_configs" in raw_config:
        for server_config in raw_config["server_configs"]:
            api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
            # Handle environment variable references like ${OPENAI_API_KEY}
            if (
                isinstance(api_key, str)
                and api_key.startswith("${")
                and api_key.endswith("}")
            ):
                env_var = api_key[2:-1]
                api_key = os.environ.get(env_var, "")

            server_configs.append(
                OpenaiConfig(
                    model_name=server_config.get("model_name", "gpt-4.1-nano"),
                    base_url=server_config.get("base_url", None),
                    api_key=api_key,
                    timeout=server_config.get("timeout", 600),
                )
            )
    else:
        # Default configuration if not specified in config file
        logger.warning(
            "No 'server_configs' found in config. Using default OpenAI config."
        )
        server_configs.append(
            OpenaiConfig(
                model_name="gpt-4.1-nano",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                timeout=600,
            )
        )
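
    # Illustrative sketch of a config file this script can consume, based only on
    # the keys read above (the values and any additional fields are assumptions;
    # see environments/dataset_environment/configs/ for real examples):
    #
    #   dataset:
    #     dataset_name: <hf-dataset-id>     # required
    #     prompt_field: <column-name>       # required
    #   reward_functions:
    #     - <registered-reward-name>        # or a mapping containing a "type" key
    #   server_configs:
    #     - model_name: gpt-4.1-nano
    #       base_url: null
    #       api_key: ${OPENAI_API_KEY}      # ${VAR} references are expanded above
    #       timeout: 600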

    # Create the environment
    logger.info("Creating dataset environment...")
    env = DatasetEnv(
        config=env_config,
        server_configs=server_configs,
        slurm=False,
    )

    # Setup the environment directly
    try:
        await env.setup()
        logger.info("Environment setup complete")
    except Exception as setup_error:
        logger.error(f"Error during environment setup: {setup_error}")
        return

    # --- Start Test Run --- #
    logger.info("\n=== Starting Local Test Run ===")
    test_items_count = 5
    successful_runs = 0

    for i in range(test_items_count):
        logger.info(f"\n--- Running Test Item {i+1}/{test_items_count} ---")
        try:
            # Get a sample item from the dataset
            item = await env.get_next_item()
            if not item or not item[0]:
                logger.warning("Failed to get a valid item from the environment.")
                continue

            prompt, answer, ground_truth = item
            user_content = dict(prompt[0])["content"]
            logger.info(
                f"Prompt: {user_content[:200]}..." if user_content else "(Empty Prompt)"
            )
            if answer:
                logger.info(
                    f"Answer: {answer[:200]}..." if answer else "(Empty Answer)"
                )
            if ground_truth:
                logger.info(
                    f"Ground Truth: {ground_truth[:200]}..."
                    if ground_truth
                    else "(Empty Ground Truth)"
                )

            # Collect trajectories (using group_size from config)
            logger.info(
                f"Collecting {env_config.group_size} trajectories for this item..."
            )
            trajectories_data, backlog = await env.collect_trajectories(item)

            if not trajectories_data:
                logger.warning("No trajectories were collected.")
                continue

            logger.info(f"Collected {len(trajectories_data)} trajectories.")
            # Log first trajectory message content for inspection
            if trajectories_data[0] and isinstance(trajectories_data[0], list):
                first_response = "(Empty or invalid trajectory format)"
                assistant_msgs = [
                    m
                    for m in trajectories_data[0]
                    if isinstance(m, dict) and m.get("role") == "assistant"
                ]
                if assistant_msgs:
                    first_response = assistant_msgs[-1].get("content", "(No content)")
                logger.info(f"First Response Content: {first_response[:300]}...")

            # Score the collected trajectories
            logger.info("Scoring trajectories...")
            scored_data = await env.score(trajectories_data)

            # Print scores
            if scored_data and "scores" in scored_data:
                scores_list = scored_data["scores"]
                logger.info(f"Scores: {scores_list}")
                logger.info(f"  Avg Score: {sum(scores_list)/len(scores_list):.4f}")
                successful_runs += 1
            else:
                logger.warning("No scores available in the scored data for this item.")

        except Exception as run_error:
            logger.error(f"Error during test item {i+1}: {run_error}")
            # Optionally continue to the next item or break
            # break

    logger.info(
        f"\n=== Local Test Run Complete ({successful_runs}/{test_items_count} items processed successfully) ==="
    )


if __name__ == "__main__":
    asyncio.run(main())
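
# Illustrative console output for a successful run (a sketch only; actual lines
# depend on the configured dataset, model, and servers):
#
#   INFO:__main__:Starting Dataset environment local server
#   INFO:__main__:Loading configuration from: environments/dataset_environment/configs/dataset_local.yaml
#   INFO:__main__:
#   === Starting Local Test Run ===
#   ...
#   INFO:__main__:
#   === Local Test Run Complete (5/5 items processed successfully) ===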