mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
184 lines
6.8 KiB
Python
184 lines
6.8 KiB
Python
import argparse
|
|
import os
|
|
from typing import Any, Dict, Optional
|
|
|
|
import torch
|
|
import yaml
|
|
|
|
|
|
class ConfigHandler:
    """Handles loading and merging of configuration files with CLI overrides.

    Configuration files are YAML files located under ``config_dir``:

    * ``envs/<name>.yaml``     -- environment config, selected by ``--env``
    * ``agents/<name>.yaml``   -- agent/model config, selected by ``--agent``;
      stored under the ``"agent"`` key of the merged config
    * ``datasets/<name>.yaml`` -- optional dataset config, selected by
      ``--config``; merged at the top level by :meth:`load_dataset_config`

    Any CLI override flag that is explicitly supplied replaces the value
    loaded from the files.
    """

    # (argparse attribute, top-level config key) pairs shared by both loaders.
    _TOP_LEVEL_OVERRIDES = (
        ("group_size", "group_size"),
        ("total_steps", "total_steps"),
        ("batch_size", "batch_size"),
        ("seed", "initial_seed"),
        ("server_url", "rollout_server_url"),
    )

    # (argparse attribute, key inside config["dataset"]) pairs.
    _DATASET_OVERRIDES = (
        ("dataset_name", "dataset_name"),
        ("dataset_split", "split"),
        ("prompt_field", "prompt_field"),
        ("answer_field", "answer_field"),
        ("system_prompt", "system_prompt"),
        ("max_generations", "max_generations_per_prompt"),
        ("reward_funcs", "reward_funcs"),
    )

    def __init__(self, config_dir: Optional[str] = None):
        """Create a handler.

        Args:
            config_dir: Directory containing the ``envs/``, ``agents/`` and
                ``datasets/`` sub-directories. When falsy, defaults to
                ``../../configs`` relative to this source file.
        """
        self.config_dir = config_dir or os.path.join(
            os.path.dirname(__file__), "../../configs"
        )
        self.parser = self._setup_argument_parser()

    def _setup_argument_parser(self) -> argparse.ArgumentParser:
        """Build the CLI parser for config selection and overrides.

        Override flags have no default, so argparse leaves them as ``None``
        when they are not passed; that is how "not supplied" is detected.
        """
        parser = argparse.ArgumentParser(description="Training configuration")

        # Config files
        parser.add_argument(
            "--env",
            type=str,
            default="crosswords",
            help="Environment config file name (without .yaml)",
        )
        parser.add_argument(
            "--agent",
            type=str,
            default="nous_hermes",
            help="Agent config file name (without .yaml)",
        )
        parser.add_argument(
            "--config",
            type=str,
            help="Configuration file name (without .yaml)",
        )

        # CLI overrides
        parser.add_argument("--group-size", type=int, help="Override group size")
        parser.add_argument("--total-steps", type=int, help="Override total steps")
        parser.add_argument("--batch-size", type=int, help="Override batch size")
        parser.add_argument("--seed", type=int, help="Override random seed")
        parser.add_argument("--device", type=str, help="Override device (cuda/cpu/mps)")
        parser.add_argument("--server-url", type=str, help="Override server URL")

        # Dataset-specific overrides
        parser.add_argument("--dataset-name", type=str, help="Override dataset name")
        parser.add_argument("--dataset-split", type=str, help="Override dataset split")
        parser.add_argument(
            "--prompt-field", type=str, help="Override prompt field name"
        )
        parser.add_argument(
            "--answer-field", type=str, help="Override answer field name"
        )
        parser.add_argument("--system-prompt", type=str, help="Override system prompt")
        parser.add_argument(
            "--max-generations", type=int, help="Override max generations per prompt"
        )
        parser.add_argument(
            "--reward-funcs",
            type=str,
            nargs="+",
            help="Override reward functions to use",
        )

        return parser

    def _config_path(self, subdir: str, name: str) -> str:
        """Return the path of ``<config_dir>/<subdir>/<name>.yaml``."""
        return os.path.join(self.config_dir, subdir, f"{name}.yaml")

    def _load_yaml(self, path: str) -> Dict[str, Any]:
        """Load a YAML configuration file as a mapping.

        An empty file yields ``{}`` rather than ``None`` so that callers can
        always index into and update the result.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            yaml.YAMLError: if the file is not valid YAML.
            ValueError: if the top-level YAML value is not a mapping.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {path}, "
                f"got {type(data).__name__}"
            )
        return data

    def _determine_device(self, config: Dict[str, Any]) -> str:
        """Resolve the ``device`` entry of an agent config to a device name.

        ``"auto"`` selects the first available backend in the order MPS,
        CUDA, CPU. Any other value is returned unchanged; a missing
        ``device`` key falls back to ``"cpu"``.
        """
        device = config.get("device", "cpu")
        if device != "auto":
            return device
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _load_base_config(self, args: argparse.Namespace) -> Dict[str, Any]:
        """Load the environment config and nest the agent config under ``"agent"``."""
        config = self._load_yaml(self._config_path("envs", args.env))
        config["agent"] = self._load_yaml(self._config_path("agents", args.agent))
        return config

    def _apply_common_overrides(
        self, config: Dict[str, Any], args: argparse.Namespace
    ) -> None:
        """Apply the top-level and device CLI overrides to ``config`` in place.

        Only flags that were actually supplied (i.e. are not ``None``) are
        applied, so falsy-but-valid values such as ``--seed 0`` still win.
        """
        for arg_name, key in self._TOP_LEVEL_OVERRIDES:
            value = getattr(args, arg_name)
            if value is not None:
                config[key] = value
        if args.device is not None:
            config["agent"]["device"] = args.device

    def _apply_dataset_overrides(
        self, dataset: Dict[str, Any], args: argparse.Namespace
    ) -> None:
        """Apply the supplied dataset-specific CLI overrides to ``dataset`` in place."""
        for arg_name, key in self._DATASET_OVERRIDES:
            value = getattr(args, arg_name)
            if value is not None:
                dataset[key] = value

    def load_config(self, args: Optional[argparse.Namespace] = None) -> Dict[str, Any]:
        """Load and merge configurations with CLI overrides.

        Args:
            args: Pre-parsed arguments; when ``None``, ``sys.argv`` is parsed
                with this handler's parser.

        Returns:
            The environment config with the agent config under ``"agent"``,
            CLI overrides applied, ``env_kwargs["player_names"]`` regenerated
            from ``group_size`` (when that key is present), and
            ``agent["device"]`` resolved to a concrete device.
        """
        if args is None:
            args = self.parser.parse_args()

        config = self._load_base_config(args)
        self._apply_common_overrides(config, args)

        # Ensure player_names is populated based on the (possibly overridden)
        # group_size.
        env_kwargs = config.get("env_kwargs")
        if isinstance(env_kwargs, dict) and "player_names" in env_kwargs:
            env_kwargs["player_names"] = {
                i: f"Player_{i}" for i in range(config["group_size"])
            }

        config["agent"]["device"] = self._determine_device(config["agent"])
        return config

    def load_dataset_config(
        self, args: Optional[argparse.Namespace] = None
    ) -> Dict[str, Any]:
        """Load and merge dataset environment configurations with CLI overrides.

        Args:
            args: Pre-parsed arguments; when ``None``, ``sys.argv`` is parsed
                with this handler's parser.

        Returns:
            The environment config with the agent config under ``"agent"``,
            the optional dataset config merged at the top level, common and
            dataset-specific CLI overrides applied, ``agent["device"]``
            resolved, and ``use_slurm`` set from the ``SLURM_JOB_ID``
            environment variable.
        """
        if args is None:
            args = self.parser.parse_args()

        config = self._load_base_config(args)

        if args.config:
            # Merge (not nest): top-level keys from the dataset file replace
            # same-named keys from the env config, including "agent".
            config.update(self._load_yaml(self._config_path("datasets", args.config)))

        self._apply_common_overrides(config, args)

        # Dataset overrides only apply when a "dataset" section exists.
        if "dataset" in config:
            self._apply_dataset_overrides(config["dataset"], args)

        config["agent"]["device"] = self._determine_device(config["agent"])

        # Add slurm flag to config if running in a Slurm environment
        config["use_slurm"] = "SLURM_JOB_ID" in os.environ

        return config
|