atropos/environments/optimizer/FOB/pytorch_fob/engine/configs.py
2025-05-18 16:36:28 -07:00

156 lines
7.1 KiB
Python

from pathlib import Path
from typing import Any, Literal, Optional
from .utils import AttributeDict, EndlessList, convert_type_inside_dict, maybe_abspath, some, wrap_list
class BaseConfig(AttributeDict):
    """Dict-backed config with attribute access; nested dicts become AttributeDicts."""
    def __init__(self, config: dict) -> None:
        # Recursively replace every plain dict inside `config` with an
        # AttributeDict so nested values support attribute-style access too.
        converted = convert_type_inside_dict(config, dict, AttributeDict)
        super().__init__(converted)
class NamedConfig(BaseConfig):
    """A config that carries a mandatory name and an output directory name.

    The name is read from `config[identifier_key]` (KeyError if missing);
    the output directory name falls back to the name when absent.
    """
    def __init__(
        self,
        config: dict[str, Any],
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        super().__init__(config)
        identifier = config[identifier_key]
        self.name = identifier
        # No explicit output dir configured -> reuse the identifier.
        self.output_dir_name = config.get(outdir_key, identifier)
class OptimizerConfig(NamedConfig):
    """Optimizer settings from `config[optimizer_key]`, augmented with the
    task's training budget (`max_steps` / `max_epochs`) from `config[task_key]`.
    """
    def __init__(
        self,
        config: dict[str, Any],
        optimizer_key: str,
        task_key: str,
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        # Shallow copy so the task limits can be injected without mutating the input.
        cfg = dict(config[optimizer_key])
        # Whether the LR schedule advances per optimizer step or per epoch.
        self.lr_interval: Literal["step", "epoch"] = cfg.get("lr_interval", "step")
        # Fixed annotation: `.get(..., None)` may return None, so this is Optional[int]
        # (was annotated plain `int`). None means "no step limit".
        self.max_steps: Optional[int] = config[task_key].get("max_steps", None)
        self.max_epochs: int = config[task_key]["max_epochs"]
        # Mirror the task's budget onto the optimizer config itself.
        cfg["max_steps"] = self.max_steps
        cfg["max_epochs"] = self.max_epochs
        super().__init__(cfg, identifier_key, outdir_key)
class TaskConfig(NamedConfig):
    """Task settings from `config[task_key]`, with `data_dir` and `workers`
    pulled in from the engine section (`config[engine_key]`).
    """
    def __init__(
        self,
        config: dict[str, Any],
        task_key: str,
        engine_key: str,
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        # Shallow copy so engine-level values can be merged in without mutating the input.
        cfg = dict(config[task_key])
        self.batch_size: int = cfg["batch_size"]
        # Resolve to an absolute path up front so downstream code never sees a relative one.
        self.data_dir = Path(config[engine_key]["data_dir"]).resolve()
        self.max_epochs: int = cfg["max_epochs"]
        # Fixed annotation: `.get(..., None)` may return None, so this is Optional[int]
        # (was annotated plain `int`). None means "no step limit".
        self.max_steps: Optional[int] = cfg.get("max_steps", None)
        self.target_metric: str = cfg["target_metric"]
        # presumably "min" or "max" — not enforced here; TODO confirm against callers.
        self.target_metric_mode: str = cfg["target_metric_mode"]
        self.workers = config[engine_key]["workers"]
        # Write the engine-derived values back so the dict backing matches the attributes.
        cfg["data_dir"] = self.data_dir
        cfg["workers"] = self.workers
        super().__init__(cfg, identifier_key, outdir_key)
class EngineConfig(BaseConfig):
    """Engine-level settings (hardware, logging, checkpointing, slurm) read from
    `config[engine_key]`. Paths are resolved/normalized and a few defaults are
    filled in before the dict backing is initialized.
    """
    def __init__(self, config: dict[str, Any], task_key: str, engine_key: str) -> None:
        # Shallow copy so normalized values can be written back without mutating the input.
        cfg = dict(config[engine_key])
        self.accelerator = cfg["accelerator"]
        self.deterministic: bool | Literal["warn"] = cfg["deterministic"]
        self.data_dir = Path(cfg["data_dir"]).resolve()
        self.detect_anomaly: bool = cfg["detect_anomaly"]
        # Default to a single device when none is configured.
        self.devices: int = some(cfg["devices"], default=1)
        self.early_stopping: Optional[int] = cfg["early_stopping"]
        # Fall back to the task's target metric when no explicit metric is given.
        self.early_stopping_metric: str = some(cfg["early_stopping_metric"], default=config[task_key]["target_metric"])
        self.gradient_clip_alg: str = cfg["gradient_clip_alg"]
        self.gradient_clip_val: Optional[float] = cfg["gradient_clip_val"]
        self.log_extra: bool | dict[str, bool] = cfg["log_extra"]
        # Fix: correctly spelled attribute for the "logging_interval" setting.
        self.logging_interval: int = cfg["logging_interval"]
        # Deprecated misspelled alias kept for backward compatibility with
        # existing callers that read `.logging_inteval`; do not use in new code.
        self.logging_inteval: int = self.logging_interval
        # Fixed annotation: `.get(..., None)` may return None, so this is Optional[int]
        # (was annotated plain `int`). None means "no step limit".
        self.max_steps: Optional[int] = config[task_key].get("max_steps", None)
        self.optimize_memory: bool = cfg["optimize_memory"]
        self.output_dir = Path(cfg["output_dir"]).resolve()
        self.plot: bool = cfg["plot"]
        self.precision: str = cfg["precision"]
        self.restrict_train_epochs: Optional[int] = cfg["restrict_train_epochs"]
        # "resume" may be a path string (resolve it), a bool, or absent (False).
        _resume = cfg.get("resume", False)
        self.resume: Optional[Path] | bool = Path(_resume).resolve() if isinstance(_resume, str) else _resume
        self.run_scheduler: str = cfg["run_scheduler"]
        self.seed: int = cfg["seed"]
        self.seed_mode: str = cfg["seed_mode"]
        # Optional paths are normalized to absolute when given, left as None otherwise.
        self.save_sbatch_scripts: Optional[Path] = maybe_abspath(cfg["save_sbatch_scripts"])
        self.sbatch_args: dict[str, str] = cfg["sbatch_args"]
        self.sbatch_script_template: Optional[Path] = maybe_abspath(cfg["sbatch_script_template"])
        self.sbatch_time_factor: float = cfg["sbatch_time_factor"]
        self.slurm_log_dir: Optional[Path] = maybe_abspath(cfg["slurm_log_dir"])
        self.silent: bool = cfg.get("silent", False)
        self.test: bool = cfg.get("test", True)
        self.train: bool = cfg.get("train", True)
        self.validate: bool = cfg.get("validate", False)
        self.workers: int = cfg["workers"]
        # Write the normalized values back so the dict backing matches the attributes.
        cfg["data_dir"] = self.data_dir
        cfg["devices"] = self.devices
        cfg["early_stopping_metric"] = self.early_stopping_metric
        cfg["max_steps"] = self.max_steps
        cfg["output_dir"] = self.output_dir
        cfg["resume"] = self.resume
        cfg["slurm_log_dir"] = self.slurm_log_dir
        cfg["save_sbatch_scripts"] = self.save_sbatch_scripts
        cfg["sbatch_script_template"] = self.sbatch_script_template
        super().__init__(cfg)

    def outpath_relevant_engine_keys(self, prefix: str = "") -> list[str]:
        """Return the engine keys that influence the output path, each prefixed
        with `prefix` (used to namespace them, e.g. "engine.")."""
        keys = [
            "accelerator",
            "deterministic",
            "detect_anomaly",
            "devices",
            "early_stopping",
            "gradient_clip_alg",
            "gradient_clip_val",
            "optimize_memory",
            "precision",
            "seed"
        ]
        return [f"{prefix}{k}" for k in keys]

    def outpath_irrelevant_engine_keys(self, prefix: str = "") -> list[str]:
        """Return all remaining engine keys (the complement of
        `outpath_relevant_engine_keys`), each prefixed with `prefix`."""
        return [f"{prefix}{k}" for k in self.keys() if k not in self.outpath_relevant_engine_keys()]
class EvalConfig(BaseConfig):
    """Settings for evaluating and plotting finished experiments, read from
    `config[eval_key]`; the default output directory is derived from the
    engine's `output_dir`.
    """
    def __init__(
        self,
        config: dict[str, Any],
        eval_key: str,
        engine_key: str,
        ignore_keys: Optional[list[str]] = None
    ) -> None:
        # Shallow copy so normalized values can be written back without mutating the input.
        cfg = dict(config[eval_key])
        # Canonical file names an experiment run is expected to produce.
        self.experiment_files = AttributeDict(dict(
            best_model = "results_best_model.json",
            last_model = "results_final_model.json",
            config = "config.yaml"
        ))
        self.output_types: list[str] = wrap_list(cfg["output_types"])
        experiment_dir = Path(config[engine_key]["output_dir"]).resolve()
        # Default plot output lives next to the experiment results.
        self.output_dir: Path = some(maybe_abspath(cfg["output_dir"]), default=experiment_dir / "plots")
        self.experiment_name: str = cfg["experiment_name"]
        self.verbose: bool = cfg.get("verbose", False)
        # Either a boolean switch or an explicit list of group keys to split on.
        split = cfg.get("split_groups", False)
        self.split_groups: bool | list[str] = split if isinstance(split, bool) else wrap_list(split)
        self.checkpoints: list[Literal["last", "best"]] = wrap_list(cfg["checkpoints"])
        self.column_split_key: Optional[str] = cfg.get("column_split_key", None)
        self.column_split_order: Optional[list[str]] = cfg.get("column_split_order", None)
        self.ignore_keys: list[str] = some(ignore_keys, default=[])
        self.aggregate_groups: list[str] = wrap_list(cfg["aggregate_groups"])
        # Write normalized values back so the dict backing matches the attributes.
        # (A redundant second `cfg["output_types"]` assignment was removed.)
        cfg["ignore_keys"] = self.ignore_keys
        cfg["output_types"] = self.output_types
        cfg["output_dir"] = self.output_dir
        cfg["aggregate_groups"] = self.aggregate_groups
        # NOTE(review): the copy above is shallow, so these two lines mutate the
        # nested "plot" dict inside the caller's `config` — presumably intentional,
        # but worth confirming.
        cfg["plot"]["x_axis"] = EndlessList(wrap_list(cfg["plot"]["x_axis"]))
        cfg["plot"]["y_axis"] = EndlessList(wrap_list(cfg["plot"]["y_axis"]))
        cfg["split_groups"] = self.split_groups
        super().__init__(cfg)