mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
156 lines
7.1 KiB
Python
156 lines
7.1 KiB
Python
from pathlib import Path
|
|
from typing import Any, Literal, Optional
|
|
from .utils import AttributeDict, EndlessList, convert_type_inside_dict, maybe_abspath, some, wrap_list
|
|
|
|
|
|
class BaseConfig(AttributeDict):
    """Common base of all config sections, backed by an ``AttributeDict``."""

    def __init__(self, config: dict):
        """Initialise the mapping from ``config``.

        Args:
            config: Raw configuration mapping. It is first passed through
                ``convert_type_inside_dict(config, dict, AttributeDict)`` --
                presumably so nested ``dict`` values become ``AttributeDict``
                instances as well (confirm against ``utils``).
        """
        converted = convert_type_inside_dict(config, dict, AttributeDict)
        super().__init__(converted)
|
|
|
|
|
|
class NamedConfig(BaseConfig):
    """Config section that carries a ``name`` and an ``output_dir_name``."""

    def __init__(
        self,
        config: dict[str, Any],
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name",
    ) -> None:
        """Initialise the section and derive its naming attributes.

        Args:
            config: Configuration mapping of this section.
            identifier_key: Key in ``config`` holding the section's name.
            outdir_key: Key in ``config`` holding the output directory name.

        Raises:
            KeyError: If ``identifier_key`` is not present in ``config``.
        """
        super().__init__(config)
        self.name = config[identifier_key]
        # Without an explicit directory name the identifier doubles as one.
        self.output_dir_name = config[outdir_key] if outdir_key in config else self.name
|
|
|
|
|
|
class OptimizerConfig(NamedConfig):
    """Named optimizer section, extended with the task's training length."""

    def __init__(
        self,
        config: dict[str, Any],
        optimizer_key: str,
        task_key: str,
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name",
    ) -> None:
        """Build the optimizer configuration.

        Args:
            config: Full configuration containing the sections stored under
                ``optimizer_key`` and ``task_key``.
            optimizer_key: Key of the optimizer section in ``config``.
            task_key: Key of the task section in ``config``; supplies
                ``max_steps`` (optional) and ``max_epochs`` (required).
            identifier_key: Forwarded to ``NamedConfig``.
            outdir_key: Forwarded to ``NamedConfig``.

        Raises:
            KeyError: If a section or the task's ``max_epochs`` is missing.
        """
        # Work on a shallow copy so the top level of the caller's optimizer
        # section is not modified by the additions below.
        optimizer_section = dict(config[optimizer_key])
        self.lr_interval: Literal["step", "epoch"] = optimizer_section.get("lr_interval", "step")

        task_section = config[task_key]
        self.max_steps: Optional[int] = task_section.get("max_steps")
        self.max_epochs: int = task_section["max_epochs"]

        # Mirror the task's training length into the optimizer section.
        optimizer_section.update(max_steps=self.max_steps, max_epochs=self.max_epochs)
        super().__init__(optimizer_section, identifier_key, outdir_key)
|
|
|
|
|
|
class TaskConfig(NamedConfig):
    """Named task section, extended with ``data_dir`` and ``workers`` from the engine."""

    def __init__(
        self,
        config: dict[str, Any],
        task_key: str,
        engine_key: str,
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name",
    ) -> None:
        """Build the task configuration.

        Args:
            config: Full configuration containing the sections stored under
                ``task_key`` and ``engine_key``.
            task_key: Key of the task section in ``config``.
            engine_key: Key of the engine section in ``config``; supplies
                ``data_dir`` and ``workers``.
            identifier_key: Forwarded to ``NamedConfig``.
            outdir_key: Forwarded to ``NamedConfig``.

        Raises:
            KeyError: If a section or one of the required entries is missing.
        """
        # Shallow copy: the additions below must not touch the caller's task section.
        task_section = dict(config[task_key])
        self.batch_size: int = task_section["batch_size"]

        engine_section = config[engine_key]
        self.data_dir = Path(engine_section["data_dir"]).resolve()
        self.max_epochs: int = task_section["max_epochs"]
        self.max_steps: Optional[int] = task_section.get("max_steps")
        self.target_metric: str = task_section["target_metric"]
        self.target_metric_mode: str = task_section["target_metric_mode"]
        self.workers = engine_section["workers"]

        # Make the engine-provided values part of the task section itself.
        task_section.update(data_dir=self.data_dir, workers=self.workers)
        super().__init__(task_section, identifier_key, outdir_key)
|
|
|
|
|
|
class EngineConfig(BaseConfig):
    """Engine section of the configuration.

    Reads the section stored under ``engine_key``, resolves its path entries,
    fills in a few fallbacks, and sets every setting as an attribute.
    """

    def __init__(self, config: dict[str, Any], task_key: str, engine_key: str) -> None:
        """Build the engine configuration.

        Args:
            config: Full configuration containing the sections stored under
                ``task_key`` and ``engine_key``.
            task_key: Key of the task section; read for ``max_steps`` and for
                ``target_metric`` (fallback of ``early_stopping_metric``).
            engine_key: Key of the engine section.

        Raises:
            KeyError: If a section or one of the required entries is missing.
        """
        # Shallow copy so the normalised values written back below do not
        # change the top level of the caller's engine section.
        cfg = dict(config[engine_key])
        task_cfg = config[task_key]

        self.accelerator = cfg["accelerator"]
        self.deterministic: bool | Literal["warn"] = cfg["deterministic"]
        self.data_dir = Path(cfg["data_dir"]).resolve()
        self.detect_anomaly: bool = cfg["detect_anomaly"]
        # NOTE(review): `some` presumably returns `default` when the value is
        # None -- confirm against utils.
        self.devices: int = some(cfg["devices"], default=1)
        self.early_stopping: Optional[int] = cfg["early_stopping"]
        # The task's target metric is looked up eagerly, so it must exist even
        # when an explicit early-stopping metric is configured.
        self.early_stopping_metric: str = some(
            cfg["early_stopping_metric"], default=task_cfg["target_metric"]
        )
        self.gradient_clip_alg: str = cfg["gradient_clip_alg"]
        self.gradient_clip_val: Optional[float] = cfg["gradient_clip_val"]
        self.log_extra: bool | dict[str, bool] = cfg["log_extra"]
        # Fixed: the attribute used to be misspelled ``logging_inteval``.
        self.logging_interval: int = cfg["logging_interval"]
        self.max_steps: Optional[int] = task_cfg.get("max_steps")
        self.optimize_memory: bool = cfg["optimize_memory"]
        self.output_dir = Path(cfg["output_dir"]).resolve()
        self.plot: bool = cfg["plot"]
        self.precision: str = cfg["precision"]
        self.restrict_train_epochs: Optional[int] = cfg["restrict_train_epochs"]
        # A string is resolved to an absolute path; any other value is kept as given.
        _resume = cfg.get("resume", False)
        self.resume: Optional[Path] | bool = Path(_resume).resolve() if isinstance(_resume, str) else _resume
        self.run_scheduler: str = cfg["run_scheduler"]
        self.seed: int = cfg["seed"]
        self.seed_mode: str = cfg["seed_mode"]
        # NOTE(review): `maybe_abspath` presumably maps None to None and any
        # other value to an absolute Path -- confirm against utils.
        self.save_sbatch_scripts: Optional[Path] = maybe_abspath(cfg["save_sbatch_scripts"])
        self.sbatch_args: dict[str, str] = cfg["sbatch_args"]
        self.sbatch_script_template: Optional[Path] = maybe_abspath(cfg["sbatch_script_template"])
        self.sbatch_time_factor: float = cfg["sbatch_time_factor"]
        self.slurm_log_dir: Optional[Path] = maybe_abspath(cfg["slurm_log_dir"])
        self.silent: bool = cfg.get("silent", False)
        self.test: bool = cfg.get("test", True)
        self.train: bool = cfg.get("train", True)
        self.validate: bool = cfg.get("validate", False)
        self.workers: int = cfg["workers"]

        # Write the normalised values back so the mapping handed to the base
        # class agrees with the attributes set above.
        cfg["data_dir"] = self.data_dir
        cfg["devices"] = self.devices
        cfg["early_stopping_metric"] = self.early_stopping_metric
        cfg["max_steps"] = self.max_steps
        cfg["output_dir"] = self.output_dir
        cfg["resume"] = self.resume
        cfg["slurm_log_dir"] = self.slurm_log_dir
        cfg["save_sbatch_scripts"] = self.save_sbatch_scripts
        cfg["sbatch_script_template"] = self.sbatch_script_template
        super().__init__(cfg)

    def outpath_relevant_engine_keys(self, prefix: str = "") -> list[str]:
        """Return the fixed list of output-path-relevant engine keys.

        Args:
            prefix: String prepended to every returned key.
        """
        keys = [
            "accelerator",
            "deterministic",
            "detect_anomaly",
            "devices",
            "early_stopping",
            "gradient_clip_alg",
            "gradient_clip_val",
            "optimize_memory",
            "precision",
            "seed",
        ]
        return [f"{prefix}{k}" for k in keys]

    def outpath_irrelevant_engine_keys(self, prefix: str = "") -> list[str]:
        """Return every key of this config that is not output-path-relevant.

        Args:
            prefix: String prepended to every returned key.
        """
        # Build the lookup once instead of once per key of this config.
        relevant = set(self.outpath_relevant_engine_keys())
        return [f"{prefix}{k}" for k in self.keys() if k not in relevant]
|
|
|
|
|
|
class EvalConfig(BaseConfig):
    """Evaluation section of the configuration."""

    def __init__(
        self,
        config: dict[str, Any],
        eval_key: str,
        engine_key: str,
        ignore_keys: Optional[list[str]] = None,
    ) -> None:
        """Build the evaluation configuration.

        Args:
            config: Full configuration containing the sections stored under
                ``eval_key`` and ``engine_key``.
            eval_key: Key of the evaluation section in ``config``.
            engine_key: Key of the engine section; its ``output_dir`` is the
                base of the fallback plot directory.
            ignore_keys: Stored as ``ignore_keys``; ``None`` is replaced by an
                empty list via ``some``.

        Raises:
            KeyError: If a section or one of the required entries is missing.
        """
        # Shallow copy so the top level of the caller's section stays untouched.
        cfg = dict(config[eval_key])

        # File names looked up inside an experiment directory.
        self.experiment_files = AttributeDict(dict(
            best_model="results_best_model.json",
            last_model="results_final_model.json",
            config="config.yaml",
        ))
        self.output_types: list[str] = wrap_list(cfg["output_types"])
        experiment_dir = Path(config[engine_key]["output_dir"]).resolve()
        # NOTE(review): presumably falls back to "<engine output_dir>/plots"
        # when no eval output_dir is configured -- depends on the semantics of
        # `some` / `maybe_abspath` in utils.
        self.output_dir: Path = some(maybe_abspath(cfg["output_dir"]), default=experiment_dir / "plots")
        self.experiment_name: str = cfg["experiment_name"]
        self.verbose: bool = cfg.get("verbose", False)
        # A bool is kept as-is; anything else goes through `wrap_list`.
        split = cfg.get("split_groups", False)
        self.split_groups: bool | list[str] = split if isinstance(split, bool) else wrap_list(split)
        self.checkpoints: list[Literal["last", "best"]] = wrap_list(cfg["checkpoints"])
        self.column_split_key: Optional[str] = cfg.get("column_split_key")
        self.column_split_order: Optional[list[str]] = cfg.get("column_split_order")
        self.ignore_keys: list[str] = some(ignore_keys, default=[])
        self.aggregate_groups: list[str] = wrap_list(cfg["aggregate_groups"])

        # Write the normalised values back so the mapping handed to the base
        # class agrees with the attributes set above.
        cfg["ignore_keys"] = self.ignore_keys
        cfg["output_types"] = self.output_types
        cfg["output_dir"] = self.output_dir
        cfg["aggregate_groups"] = self.aggregate_groups

        # `cfg` is only a shallow copy, so `cfg["plot"]` is still the caller's
        # own mapping. Copy it before replacing the axis entries; otherwise the
        # input configuration would be modified as a side effect.
        plot_cfg = dict(cfg["plot"])
        plot_cfg["x_axis"] = EndlessList(wrap_list(plot_cfg["x_axis"]))
        plot_cfg["y_axis"] = EndlessList(wrap_list(plot_cfg["y_axis"]))
        cfg["plot"] = plot_cfg

        cfg["split_groups"] = self.split_groups
        super().__init__(cfg)
|