mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
228 lines
10 KiB
Python
228 lines
10 KiB
Python
import json
|
|
from copy import deepcopy
|
|
from typing import Any, Callable, Iterable, Iterator, Literal, Optional
|
|
from pathlib import Path
|
|
from matplotlib.figure import Figure
|
|
from pandas import DataFrame, concat, json_normalize
|
|
from pytorch_fob.engine.configs import EvalConfig
|
|
from pytorch_fob.engine.grid_search import grid_search
|
|
from pytorch_fob.engine.parser import YAMLParser
|
|
from pytorch_fob.engine.run import Run
|
|
from pytorch_fob.engine.run_schedulers import sequential, slurm_array, slurm_jobs
|
|
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, some, sort_dict_recursively
|
|
from pytorch_fob.evaluation import evaluation_path
|
|
from pytorch_fob.evaluation.plot import create_figure, get_output_file_path, save_files, set_plotstyle
|
|
from pytorch_fob.optimizers import lr_schedulers_path, optimizer_path, optimizer_names
|
|
from pytorch_fob.tasks import task_path, task_names
|
|
|
|
|
|
def engine_path() -> Path:
|
|
return Path(__file__).resolve().parent
|
|
|
|
|
|
class Engine():
|
|
def __init__(self) -> None:
|
|
self._runs = []
|
|
self._defaults = []
|
|
self._experiment = {}
|
|
self._experiment_file = None
|
|
self._block_plotting = False
|
|
self.task_key = "task"
|
|
self.optimizer_key = "optimizer"
|
|
self.engine_key = "engine"
|
|
self.eval_key = "evaluation"
|
|
self.identifier_key = "name"
|
|
self.default_file_name = "default.yaml"
|
|
self.parser = YAMLParser()
|
|
|
|
def run_experiment(self) -> Optional[list[int]]:
|
|
assert len(self._runs) > 0, "No runs in experiment, make sure to call 'parse_experiment' first."
|
|
scheduler = self._runs[0][self.engine_key]["run_scheduler"]
|
|
assert all(map(lambda x: x[self.engine_key]["run_scheduler"] == scheduler, self._runs)), \
|
|
"You cannot perform gridsearch on 'run_scheduler'."
|
|
if scheduler == "sequential":
|
|
sequential(self.runs(), len(self._runs), self._experiment)
|
|
elif scheduler.startswith("single"):
|
|
n = int(scheduler.rsplit(":", 1)[-1])
|
|
log_info(f"Starting run {n}/{len(self._runs)}.")
|
|
run = self._make_run(n)
|
|
run.start()
|
|
elif scheduler == "slurm_array":
|
|
self._block_plotting = True
|
|
slurm_array(list(self.runs()), self._experiment)
|
|
elif scheduler == "slurm_jobs":
|
|
self._block_plotting = True
|
|
return slurm_jobs(list(self.runs()), self._experiment)
|
|
else:
|
|
raise ValueError(f"Unsupported run_scheduler: {scheduler=}.")
|
|
|
|
def parse_experiment_from_file(self, file: Path, extra_args: Iterable[str] = tuple()):
|
|
self._experiment_file = file.resolve()
|
|
searchspace: dict[str, Any] = self.parser.parse_yaml(self._experiment_file)
|
|
self.parse_experiment(searchspace, extra_args)
|
|
|
|
def parse_experiment(self, searchspace: dict[str, Any], extra_args: Iterable[str] = tuple()):
|
|
self.parser.parse_args_into_searchspace(searchspace, extra_args)
|
|
# normalize experiment
|
|
self._named_dicts_to_list(
|
|
searchspace,
|
|
[self.optimizer_key, self.task_key],
|
|
[optimizer_names(), task_names()]
|
|
)
|
|
searchspace = sort_dict_recursively(searchspace)
|
|
self._experiment = deepcopy(searchspace)
|
|
# exclude plotting from gridsearch
|
|
if self.eval_key in searchspace:
|
|
eval_config = searchspace.pop(self.eval_key)
|
|
else:
|
|
eval_config = {}
|
|
log_debug("Performing gridsearch...")
|
|
self._runs = grid_search(searchspace)
|
|
log_debug(f"Found {len(self._runs)} runs.")
|
|
for run in self._runs:
|
|
run[self.eval_key] = eval_config
|
|
self._fill_runs_from_default(self._runs)
|
|
self._fill_defaults()
|
|
|
|
def runs(self) -> Iterator[Run]:
|
|
"""
|
|
Creates and initializes runs from parsed run config.
|
|
"""
|
|
for n, _ in enumerate(self._runs, start=1):
|
|
yield self._make_run(n)
|
|
|
|
def prepare_data(self):
|
|
prepared = set()
|
|
for n, t in enumerate(self._runs, start=1):
|
|
name = t["task"]["name"]
|
|
if name not in prepared:
|
|
run = self._make_run(n)
|
|
log_info(f"Setting up data for {run.task_key} '{run.task.name}'...")
|
|
run.get_datamodule().prepare_data()
|
|
log_info("... finished.")
|
|
prepared.add(name)
|
|
|
|
def plot(self, save: bool = True) -> list[Figure]:
|
|
run = next(self.runs())
|
|
if self._block_plotting or not run.engine.plot:
|
|
return []
|
|
config = run.evaluation
|
|
set_plotstyle(config)
|
|
figs = []
|
|
for mode in config.checkpoints:
|
|
df = self.dataframe_from_runs(mode)
|
|
if config.plot.single_file:
|
|
fig, dfs = self.plot_one_fig(df, config)
|
|
if save:
|
|
self.save_one_plot(fig, dfs, config, mode)
|
|
figs.append(fig)
|
|
else:
|
|
# TODO: option to split into multiple files
|
|
raise NotImplementedError("evaluation.plot.single_file=False is not implemented yet.")
|
|
return figs
|
|
|
|
def plot_one_fig(self, df: DataFrame, config: EvalConfig):
|
|
if config.column_split_key is None:
|
|
dfs = [df]
|
|
else:
|
|
groups = df.groupby(config.column_split_key)
|
|
order = some(config.column_split_order, default=map(lambda x: x[0], sorted(groups)))
|
|
dfs: list[DataFrame] = [groups.get_group(group_name) for group_name in order]
|
|
fig, _ = create_figure(dfs, config)
|
|
return fig, dfs
|
|
|
|
def save_one_plot(self, fig, dfs: list[DataFrame], config: EvalConfig, mode: Literal["last", "best"]):
|
|
output_file_path = get_output_file_path(dfs, config, suffix=mode)
|
|
save_files(fig, dfs, output_file_path, config)
|
|
|
|
def dataframe_from_runs(self, mode: Literal["last", "best"]) -> DataFrame:
|
|
dfs: list[DataFrame] = []
|
|
for run in self.runs():
|
|
df = json_normalize(run.get_config())
|
|
if mode == "last":
|
|
result_file = run.run_dir / run.evaluation.experiment_files.last_model
|
|
elif mode == "best":
|
|
result_file = run.run_dir / run.evaluation.experiment_files.best_model
|
|
else:
|
|
raise ValueError(f"mode {mode} not supported")
|
|
if not result_file.is_file():
|
|
log_warn(f"result file {result_file} not found, skipping this hyperparameter setting")
|
|
continue
|
|
metric = run.evaluation.plot.metric
|
|
with open(result_file, "r", encoding="utf8") as f:
|
|
content = json.load(f)
|
|
if metric in content[0]:
|
|
df.at[0, metric] = content[0][metric]
|
|
else:
|
|
log_warn(f"could not find value for {metric} in json, skipping this hyperparameter setting")
|
|
continue
|
|
dfs.append(df)
|
|
if len(dfs) == 0:
|
|
raise ValueError("no dataframes found, check your config")
|
|
return concat(dfs, sort=False)
|
|
|
|
def _make_run(self, n: int) -> Run:
|
|
"""
|
|
n: number of the run, starting from 1
|
|
setup: download and prepare data
|
|
"""
|
|
i = n - 1
|
|
return Run(
|
|
self._runs[i],
|
|
self._defaults[i],
|
|
self.task_key,
|
|
self.optimizer_key,
|
|
self.engine_key,
|
|
self.eval_key,
|
|
self.identifier_key
|
|
)
|
|
|
|
def _named_dicts_to_list(self, searchspace: dict[str, Any], keys: list[str], valid_options: list[list[str]]):
|
|
assert len(keys) == len(valid_options)
|
|
for key, opts in zip(keys, valid_options):
|
|
if key not in searchspace:
|
|
continue
|
|
if isinstance(searchspace[key], dict) and all(name in opts for name in searchspace[key]):
|
|
searchspace[key] = [cfg | {self.identifier_key: name} for name, cfg in searchspace[key].items()]
|
|
|
|
def _fill_defaults(self):
|
|
self._defaults = []
|
|
for run in self._runs:
|
|
default_cfg = {
|
|
k: {self.identifier_key: run[k][self.identifier_key]}
|
|
for k in [self.task_key, self.optimizer_key]
|
|
}
|
|
self._defaults.append(default_cfg)
|
|
self._fill_runs_from_default(self._defaults)
|
|
|
|
def _fill_runs_from_default(self, runs: list[dict[str, Any]]):
|
|
for i, _ in enumerate(runs):
|
|
# order from higher to lower in hierarchy
|
|
runs[i] = self._fill_named_from_default(runs[i], self.task_key, task_path)
|
|
runs[i] = self._fill_named_from_default(runs[i], self.optimizer_key, optimizer_path)
|
|
runs[i] = self._fill_unnamed_from_default(runs[i], lr_schedulers_path)
|
|
runs[i] = self._fill_unnamed_from_default(runs[i], engine_path)
|
|
runs[i] = self._fill_unnamed_from_default(runs[i], evaluation_path)
|
|
|
|
def _fill_unnamed_from_default(self, experiment: dict[str, Any], unnamed_root: Callable) -> dict[str, Any]:
|
|
default_path: Path = unnamed_root() / self.default_file_name
|
|
default_config = self.parser.parse_yaml(default_path)
|
|
self.parser.merge_dicts_hierarchical(default_config, experiment)
|
|
return default_config
|
|
|
|
def _fill_named_from_default(self, experiment: dict[str, Any], key: str, named_root: Callable) -> dict[str, Any]:
|
|
self._argcheck_named(experiment, key, self.identifier_key)
|
|
named = experiment[key]
|
|
if isinstance(named, dict):
|
|
named = named[self.identifier_key]
|
|
else:
|
|
experiment[key] = {self.identifier_key: named}
|
|
default_path: Path = named_root(named) / self.default_file_name
|
|
default_config = self.parser.parse_yaml(default_path)
|
|
self.parser.merge_dicts_hierarchical(default_config, experiment)
|
|
return default_config
|
|
|
|
def _argcheck_named(self, experiment: dict[str, Any], key: str, identifier: str):
|
|
assert key in experiment, f"You did not provide any {key}."
|
|
assert isinstance(experiment[key], str) or identifier in experiment[key], \
|
|
f"Unknown {key}, either specify only a string or provide a key '{identifier}'"
|