atropos/environments/optimizer/FOB/pytorch_fob/engine/engine.py
import json
from copy import deepcopy
from typing import Any, Callable, Iterable, Iterator, Literal, Optional
from pathlib import Path
from matplotlib.figure import Figure
from pandas import DataFrame, concat, json_normalize
from pytorch_fob.engine.configs import EvalConfig
from pytorch_fob.engine.grid_search import grid_search
from pytorch_fob.engine.parser import YAMLParser
from pytorch_fob.engine.run import Run
from pytorch_fob.engine.run_schedulers import sequential, slurm_array, slurm_jobs
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, some, sort_dict_recursively
from pytorch_fob.evaluation import evaluation_path
from pytorch_fob.evaluation.plot import create_figure, get_output_file_path, save_files, set_plotstyle
from pytorch_fob.optimizers import lr_schedulers_path, optimizer_path, optimizer_names
from pytorch_fob.tasks import task_path, task_names


def engine_path() -> Path:
    return Path(__file__).resolve().parent


class Engine():
    def __init__(self) -> None:
        self._runs = []
        self._defaults = []
        self._experiment = {}
        self._experiment_file = None
        self._block_plotting = False
        self.task_key = "task"
        self.optimizer_key = "optimizer"
        self.engine_key = "engine"
        self.eval_key = "evaluation"
        self.identifier_key = "name"
        self.default_file_name = "default.yaml"
        self.parser = YAMLParser()

    def run_experiment(self) -> Optional[list[int]]:
        """
        Executes all runs of the experiment using the configured 'run_scheduler'.
        Only the 'slurm_jobs' scheduler returns a value (the ids of the submitted jobs).
        """
        assert len(self._runs) > 0, "No runs in experiment, make sure to call 'parse_experiment' first."
        scheduler = self._runs[0][self.engine_key]["run_scheduler"]
        assert all(map(lambda x: x[self.engine_key]["run_scheduler"] == scheduler, self._runs)), \
            "You cannot perform gridsearch on 'run_scheduler'."
        if scheduler == "sequential":
            sequential(self.runs(), len(self._runs), self._experiment)
        elif scheduler.startswith("single"):
            n = int(scheduler.rsplit(":", 1)[-1])
            log_info(f"Starting run {n}/{len(self._runs)}.")
            run = self._make_run(n)
            run.start()
        elif scheduler == "slurm_array":
            self._block_plotting = True
            slurm_array(list(self.runs()), self._experiment)
        elif scheduler == "slurm_jobs":
            self._block_plotting = True
            return slurm_jobs(list(self.runs()), self._experiment)
        else:
            raise ValueError(f"Unsupported run_scheduler: {scheduler=}.")
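
    # Valid values for 'engine.run_scheduler', as dispatched above, are
    # "sequential", "single:<n>", "slurm_array" and "slurm_jobs". For example,
    # "single:3" executes only the third run of the grid (runs are numbered
    # starting from 1).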

    def parse_experiment_from_file(self, file: Path, extra_args: Iterable[str] = tuple()):
        self._experiment_file = file.resolve()
        searchspace: dict[str, Any] = self.parser.parse_yaml(self._experiment_file)
        self.parse_experiment(searchspace, extra_args)

    def parse_experiment(self, searchspace: dict[str, Any], extra_args: Iterable[str] = tuple()):
        """
        Normalizes the given searchspace, performs the gridsearch and prepares
        the configuration of every run in the experiment.
        """
        self.parser.parse_args_into_searchspace(searchspace, extra_args)
        # normalize experiment
        self._named_dicts_to_list(
            searchspace,
            [self.optimizer_key, self.task_key],
            [optimizer_names(), task_names()]
        )
        searchspace = sort_dict_recursively(searchspace)
        self._experiment = deepcopy(searchspace)
        # exclude plotting from gridsearch
        if self.eval_key in searchspace:
            eval_config = searchspace.pop(self.eval_key)
        else:
            eval_config = {}
        log_debug("Performing gridsearch...")
        self._runs = grid_search(searchspace)
        log_debug(f"Found {len(self._runs)} runs.")
        for run in self._runs:
            run[self.eval_key] = eval_config
        self._fill_runs_from_default(self._runs)
        self._fill_defaults()
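
    # Illustrative example of a searchspace accepted by 'parse_experiment' (the
    # task and optimizer names below are placeholders, and this assumes that
    # 'grid_search' expands every list-valued entry into one axis of the grid):
    #
    #     {
    #         "task": {"name": "mnist"},
    #         "optimizer": {"name": "adamw_baseline", "learning_rate": [1e-3, 1e-4]},
    #         "engine": {"seed": [1, 2]},
    #         "evaluation": {"plot": {"metric": "test_acc"}},
    #     }
    #
    # Under these assumptions the example expands into 2 x 2 = 4 runs, while the
    # "evaluation" section is kept out of the grid and attached to every run.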

    def runs(self) -> Iterator[Run]:
        """
        Creates and initializes runs from parsed run config.
        """
        for n, _ in enumerate(self._runs, start=1):
            yield self._make_run(n)

    def prepare_data(self):
        """
        Downloads and prepares the data of every task used in the experiment,
        each task only once.
        """
        prepared = set()
        for n, t in enumerate(self._runs, start=1):
            name = t["task"]["name"]
            if name not in prepared:
                run = self._make_run(n)
                log_info(f"Setting up data for {run.task_key} '{run.task.name}'...")
                run.get_datamodule().prepare_data()
                log_info("... finished.")
                prepared.add(name)

    def plot(self, save: bool = True) -> list[Figure]:
        """
        Creates the evaluation plots for the experiment and returns the figures.
        Returns an empty list when plotting is disabled or blocked by the run scheduler.
        """
        run = next(self.runs())
        if self._block_plotting or not run.engine.plot:
            return []
        config = run.evaluation
        set_plotstyle(config)
        figs = []
        for mode in config.checkpoints:
            df = self.dataframe_from_runs(mode)
            if config.plot.single_file:
                fig, dfs = self.plot_one_fig(df, config)
                if save:
                    self.save_one_plot(fig, dfs, config, mode)
                figs.append(fig)
            else:
                # TODO: option to split into multiple files
                raise NotImplementedError("evaluation.plot.single_file=False is not implemented yet.")
        return figs

    def plot_one_fig(self, df: DataFrame, config: EvalConfig):
        if config.column_split_key is None:
            dfs = [df]
        else:
            groups = df.groupby(config.column_split_key)
            order = some(config.column_split_order, default=map(lambda x: x[0], sorted(groups)))
            dfs: list[DataFrame] = [groups.get_group(group_name) for group_name in order]
        fig, _ = create_figure(dfs, config)
        return fig, dfs

    def save_one_plot(self, fig, dfs: list[DataFrame], config: EvalConfig, mode: Literal["last", "best"]):
        output_file_path = get_output_file_path(dfs, config, suffix=mode)
        save_files(fig, dfs, output_file_path, config)

    def dataframe_from_runs(self, mode: Literal["last", "best"]) -> DataFrame:
        """
        Collects the config and the target metric of every finished run into a
        single DataFrame; runs without a usable result file are skipped.
        """
        dfs: list[DataFrame] = []
        for run in self.runs():
            df = json_normalize(run.get_config())
            if mode == "last":
                result_file = run.run_dir / run.evaluation.experiment_files.last_model
            elif mode == "best":
                result_file = run.run_dir / run.evaluation.experiment_files.best_model
            else:
                raise ValueError(f"mode {mode} not supported")
            if not result_file.is_file():
                log_warn(f"result file {result_file} not found, skipping this hyperparameter setting")
                continue
            metric = run.evaluation.plot.metric
            with open(result_file, "r", encoding="utf8") as f:
                content = json.load(f)
            if metric in content[0]:
                df.at[0, metric] = content[0][metric]
            else:
                log_warn(f"could not find value for {metric} in json, skipping this hyperparameter setting")
                continue
            dfs.append(df)
        if len(dfs) == 0:
            raise ValueError("no dataframes found, check your config")
        return concat(dfs, sort=False)
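
    # The result files read above are expected to contain a JSON list whose
    # first entry holds the logged metrics, e.g. (illustrative names and values):
    #
    #     [{"test_acc": 0.91, "test_loss": 0.31}]
    #
    # Only the metric named in 'evaluation.plot.metric' is copied into the
    # dataframe; runs whose file is missing or lacks that metric are skipped.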

    def _make_run(self, n: int) -> Run:
        """
        Creates the Run object for the n-th run of the experiment.

        n: number of the run, starting from 1
        """
        i = n - 1
        return Run(
            self._runs[i],
            self._defaults[i],
            self.task_key,
            self.optimizer_key,
            self.engine_key,
            self.eval_key,
            self.identifier_key
        )

    def _named_dicts_to_list(self, searchspace: dict[str, Any], keys: list[str], valid_options: list[list[str]]):
        assert len(keys) == len(valid_options)
        for key, opts in zip(keys, valid_options):
            if key not in searchspace:
                continue
            if isinstance(searchspace[key], dict) and all(name in opts for name in searchspace[key]):
                searchspace[key] = [cfg | {self.identifier_key: name} for name, cfg in searchspace[key].items()]
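
    # Example of the normalization performed by '_named_dicts_to_list' (the
    # optimizer names are placeholders): a mapping of valid names to configs such as
    #
    #     {"optimizer": {"adamw": {"beta1": 0.9}, "sgd": {}}}
    #
    # is rewritten to
    #
    #     {"optimizer": [{"beta1": 0.9, "name": "adamw"}, {"name": "sgd"}]}
    #
    # so that the gridsearch later yields one run per named configuration.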

    def _fill_defaults(self):
        self._defaults = []
        for run in self._runs:
            default_cfg = {
                k: {self.identifier_key: run[k][self.identifier_key]}
                for k in [self.task_key, self.optimizer_key]
            }
            self._defaults.append(default_cfg)
        self._fill_runs_from_default(self._defaults)

    def _fill_runs_from_default(self, runs: list[dict[str, Any]]):
        for i, _ in enumerate(runs):
            # order from higher to lower in hierarchy
            runs[i] = self._fill_named_from_default(runs[i], self.task_key, task_path)
            runs[i] = self._fill_named_from_default(runs[i], self.optimizer_key, optimizer_path)
            runs[i] = self._fill_unnamed_from_default(runs[i], lr_schedulers_path)
            runs[i] = self._fill_unnamed_from_default(runs[i], engine_path)
            runs[i] = self._fill_unnamed_from_default(runs[i], evaluation_path)

    def _fill_unnamed_from_default(self, experiment: dict[str, Any], unnamed_root: Callable) -> dict[str, Any]:
        default_path: Path = unnamed_root() / self.default_file_name
        default_config = self.parser.parse_yaml(default_path)
        self.parser.merge_dicts_hierarchical(default_config, experiment)
        return default_config

    def _fill_named_from_default(self, experiment: dict[str, Any], key: str, named_root: Callable) -> dict[str, Any]:
        self._argcheck_named(experiment, key, self.identifier_key)
        named = experiment[key]
        if isinstance(named, dict):
            named = named[self.identifier_key]
        else:
            experiment[key] = {self.identifier_key: named}
        default_path: Path = named_root(named) / self.default_file_name
        default_config = self.parser.parse_yaml(default_path)
        self.parser.merge_dicts_hierarchical(default_config, experiment)
        return default_config

    def _argcheck_named(self, experiment: dict[str, Any], key: str, identifier: str):
        assert key in experiment, f"You did not provide any {key}."
        assert isinstance(experiment[key], str) or identifier in experiment[key], \
            f"Unknown {key}, either specify only a string or provide a key '{identifier}'."