atropos/atroposlib/utils/config_handler.py
2025-05-09 09:54:20 -05:00

184 lines
6.8 KiB
Python

import argparse
import os
from typing import Any, Dict, Optional
import torch
import yaml
class ConfigHandler:
    """Handles loading and merging of configuration files with CLI overrides.

    Configuration is read from YAML files located under ``config_dir`` in
    three sub-directories: ``envs/`` (environment configs), ``agents/``
    (agent/model configs) and ``datasets/`` (dataset configs).  CLI flags
    override individual values after the files are merged.
    """

    def __init__(self, config_dir: Optional[str] = None):
        """Initialize the handler.

        Args:
            config_dir: Root directory containing the config sub-directories.
                Defaults to ``../../configs`` relative to this file.
        """
        self.config_dir = config_dir or os.path.join(
            os.path.dirname(__file__), "../../configs"
        )
        self.parser = self._setup_argument_parser()

    def _setup_argument_parser(self) -> argparse.ArgumentParser:
        """Build the CLI parser: config-file selectors plus value overrides.

        All override flags default to ``None`` so "not supplied" is
        distinguishable from a falsy-but-valid value (e.g. ``--seed 0``).
        """
        parser = argparse.ArgumentParser(description="Training configuration")
        # Config files
        parser.add_argument(
            "--env",
            type=str,
            default="crosswords",
            help="Environment config file name (without .yaml)",
        )
        parser.add_argument(
            "--agent",
            type=str,
            default="nous_hermes",
            help="Agent config file name (without .yaml)",
        )
        parser.add_argument(
            "--config",
            type=str,
            help="Configuration file name (without .yaml)",
        )
        # CLI overrides
        parser.add_argument("--group-size", type=int, help="Override group size")
        parser.add_argument("--total-steps", type=int, help="Override total steps")
        parser.add_argument("--batch-size", type=int, help="Override batch size")
        parser.add_argument("--seed", type=int, help="Override random seed")
        parser.add_argument("--device", type=str, help="Override device (cuda/cpu/mps)")
        parser.add_argument("--server-url", type=str, help="Override server URL")
        # Dataset-specific overrides
        parser.add_argument("--dataset-name", type=str, help="Override dataset name")
        parser.add_argument("--dataset-split", type=str, help="Override dataset split")
        parser.add_argument(
            "--prompt-field", type=str, help="Override prompt field name"
        )
        parser.add_argument(
            "--answer-field", type=str, help="Override answer field name"
        )
        parser.add_argument("--system-prompt", type=str, help="Override system prompt")
        parser.add_argument(
            "--max-generations", type=int, help="Override max generations per prompt"
        )
        parser.add_argument(
            "--reward-funcs",
            type=str,
            nargs="+",
            help="Override reward functions to use",
        )
        return parser

    def _load_yaml(self, path: str) -> Dict[str, Any]:
        """Load a YAML configuration file.

        Returns:
            The parsed mapping; an empty dict for an empty file
            (``yaml.safe_load`` returns ``None`` in that case).
        """
        with open(path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def _determine_device(self, config: Dict[str, Any]) -> str:
        """Resolve the ``device`` entry of an agent config.

        ``"auto"`` picks the best available backend in priority order
        mps > cuda > cpu; any other explicit value is passed through,
        and a missing entry defaults to ``"cpu"``.
        """
        if config.get("device") == "auto":
            if torch.backends.mps.is_available():
                return "mps"
            if torch.cuda.is_available():
                return "cuda"
            return "cpu"
        return config.get("device", "cpu")

    def _apply_common_overrides(
        self, config: Dict[str, Any], args: argparse.Namespace
    ) -> None:
        """Apply the CLI overrides shared by both load paths, in place.

        Tests ``is not None`` rather than truthiness so falsy-but-valid
        values (e.g. ``--seed 0``) are still applied.  ``config["agent"]``
        must already be populated before calling.
        """
        if args.group_size is not None:
            config["group_size"] = args.group_size
        if args.total_steps is not None:
            config["total_steps"] = args.total_steps
        if args.batch_size is not None:
            config["batch_size"] = args.batch_size
        if args.seed is not None:
            config["initial_seed"] = args.seed
        if args.device is not None:
            config["agent"]["device"] = args.device
        if args.server_url is not None:
            config["rollout_server_url"] = args.server_url

    def load_config(self, args: Optional[argparse.Namespace] = None) -> Dict[str, Any]:
        """Load and merge environment + agent configs with CLI overrides.

        Args:
            args: Pre-parsed namespace; when ``None`` the instance parser
                consumes ``sys.argv``.

        Returns:
            The merged config dict with the agent config nested under
            ``"agent"`` and its device resolved.
        """
        if args is None:
            args = self.parser.parse_args()
        # environment config
        config = self._load_yaml(os.path.join(self.config_dir, f"envs/{args.env}.yaml"))
        # agent/model config
        config["agent"] = self._load_yaml(
            os.path.join(self.config_dir, f"agents/{args.agent}.yaml")
        )
        # CLI overrides (after both files, so flags win)
        self._apply_common_overrides(config, args)
        # Ensure player_names is populated based on group_size.
        # NOTE(review): assumes "group_size" is present whenever env_kwargs
        # declares player_names — a KeyError here indicates a bad env config.
        if "env_kwargs" in config and "player_names" in config["env_kwargs"]:
            config["env_kwargs"]["player_names"] = {
                i: f"Player_{i}" for i in range(config["group_size"])
            }
        config["agent"]["device"] = self._determine_device(config["agent"])
        return config

    def load_dataset_config(
        self, args: Optional[argparse.Namespace] = None
    ) -> Dict[str, Any]:
        """Load and merge dataset environment configurations with CLI overrides.

        Like :meth:`load_config`, but additionally merges a dataset config
        (flat, into the top level) when ``--config`` is given, applies the
        dataset-specific override flags, and records whether the process is
        running under Slurm.
        """
        if args is None:
            args = self.parser.parse_args()
        # Start with base environment config
        config = self._load_yaml(os.path.join(self.config_dir, f"envs/{args.env}.yaml"))
        # Load agent config
        config["agent"] = self._load_yaml(
            os.path.join(self.config_dir, f"agents/{args.agent}.yaml")
        )
        # Load dataset config if specified; merge flat into the main config
        # instead of nesting it under a key.
        if args.config:
            dataset_config = self._load_yaml(
                os.path.join(self.config_dir, f"datasets/{args.config}.yaml")
            )
            config.update(dataset_config)
        # Apply CLI overrides for common parameters
        self._apply_common_overrides(config, args)
        # Apply dataset-specific overrides (only meaningful when the merged
        # config carries a "dataset" section).
        if "dataset" in config:
            dataset = config["dataset"]
            if args.dataset_name is not None:
                dataset["dataset_name"] = args.dataset_name
            if args.dataset_split is not None:
                dataset["split"] = args.dataset_split
            if args.prompt_field is not None:
                dataset["prompt_field"] = args.prompt_field
            if args.answer_field is not None:
                dataset["answer_field"] = args.answer_field
            if args.system_prompt is not None:
                dataset["system_prompt"] = args.system_prompt
            if args.max_generations is not None:
                dataset["max_generations_per_prompt"] = args.max_generations
            if args.reward_funcs is not None:
                dataset["reward_funcs"] = args.reward_funcs
        # Set device
        config["agent"]["device"] = self._determine_device(config["agent"])
        # Add slurm flag to config if running in a Slurm environment
        config["use_slurm"] = "SLURM_JOB_ID" in os.environ
        return config