Convert FOB submodule to regular folder

This commit is contained in:
arihanv 2025-05-18 16:36:28 -07:00
parent 94f046ad40
commit 94825011a0
74 changed files with 4563 additions and 0 deletions

View file

@ -0,0 +1,228 @@
import json
from copy import deepcopy
from typing import Any, Callable, Iterable, Iterator, Literal, Optional
from pathlib import Path
from matplotlib.figure import Figure
from pandas import DataFrame, concat, json_normalize
from pytorch_fob.engine.configs import EvalConfig
from pytorch_fob.engine.grid_search import grid_search
from pytorch_fob.engine.parser import YAMLParser
from pytorch_fob.engine.run import Run
from pytorch_fob.engine.run_schedulers import sequential, slurm_array, slurm_jobs
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, some, sort_dict_recursively
from pytorch_fob.evaluation import evaluation_path
from pytorch_fob.evaluation.plot import create_figure, get_output_file_path, save_files, set_plotstyle
from pytorch_fob.optimizers import lr_schedulers_path, optimizer_path, optimizer_names
from pytorch_fob.tasks import task_path, task_names
def engine_path() -> Path:
return Path(__file__).resolve().parent
class Engine():
def __init__(self) -> None:
self._runs = []
self._defaults = []
self._experiment = {}
self._experiment_file = None
self._block_plotting = False
self.task_key = "task"
self.optimizer_key = "optimizer"
self.engine_key = "engine"
self.eval_key = "evaluation"
self.identifier_key = "name"
self.default_file_name = "default.yaml"
self.parser = YAMLParser()
def run_experiment(self) -> Optional[list[int]]:
assert len(self._runs) > 0, "No runs in experiment, make sure to call 'parse_experiment' first."
scheduler = self._runs[0][self.engine_key]["run_scheduler"]
assert all(map(lambda x: x[self.engine_key]["run_scheduler"] == scheduler, self._runs)), \
"You cannot perform gridsearch on 'run_scheduler'."
if scheduler == "sequential":
sequential(self.runs(), len(self._runs), self._experiment)
elif scheduler.startswith("single"):
n = int(scheduler.rsplit(":", 1)[-1])
log_info(f"Starting run {n}/{len(self._runs)}.")
run = self._make_run(n)
run.start()
elif scheduler == "slurm_array":
self._block_plotting = True
slurm_array(list(self.runs()), self._experiment)
elif scheduler == "slurm_jobs":
self._block_plotting = True
return slurm_jobs(list(self.runs()), self._experiment)
else:
raise ValueError(f"Unsupported run_scheduler: {scheduler=}.")
def parse_experiment_from_file(self, file: Path, extra_args: Iterable[str] = tuple()):
self._experiment_file = file.resolve()
searchspace: dict[str, Any] = self.parser.parse_yaml(self._experiment_file)
self.parse_experiment(searchspace, extra_args)
def parse_experiment(self, searchspace: dict[str, Any], extra_args: Iterable[str] = tuple()):
self.parser.parse_args_into_searchspace(searchspace, extra_args)
# normalize experiment
self._named_dicts_to_list(
searchspace,
[self.optimizer_key, self.task_key],
[optimizer_names(), task_names()]
)
searchspace = sort_dict_recursively(searchspace)
self._experiment = deepcopy(searchspace)
# exclude plotting from gridsearch
if self.eval_key in searchspace:
eval_config = searchspace.pop(self.eval_key)
else:
eval_config = {}
log_debug("Performing gridsearch...")
self._runs = grid_search(searchspace)
log_debug(f"Found {len(self._runs)} runs.")
for run in self._runs:
run[self.eval_key] = eval_config
self._fill_runs_from_default(self._runs)
self._fill_defaults()
def runs(self) -> Iterator[Run]:
"""
Creates and initializes runs from parsed run config.
"""
for n, _ in enumerate(self._runs, start=1):
yield self._make_run(n)
def prepare_data(self):
prepared = set()
for n, t in enumerate(self._runs, start=1):
name = t["task"]["name"]
if name not in prepared:
run = self._make_run(n)
log_info(f"Setting up data for {run.task_key} '{run.task.name}'...")
run.get_datamodule().prepare_data()
log_info("... finished.")
prepared.add(name)
def plot(self, save: bool = True) -> list[Figure]:
run = next(self.runs())
if self._block_plotting or not run.engine.plot:
return []
config = run.evaluation
set_plotstyle(config)
figs = []
for mode in config.checkpoints:
df = self.dataframe_from_runs(mode)
if config.plot.single_file:
fig, dfs = self.plot_one_fig(df, config)
if save:
self.save_one_plot(fig, dfs, config, mode)
figs.append(fig)
else:
# TODO: option to split into multiple files
raise NotImplementedError("evaluation.plot.single_file=False is not implemented yet.")
return figs
def plot_one_fig(self, df: DataFrame, config: EvalConfig):
if config.column_split_key is None:
dfs = [df]
else:
groups = df.groupby(config.column_split_key)
order = some(config.column_split_order, default=map(lambda x: x[0], sorted(groups)))
dfs: list[DataFrame] = [groups.get_group(group_name) for group_name in order]
fig, _ = create_figure(dfs, config)
return fig, dfs
def save_one_plot(self, fig, dfs: list[DataFrame], config: EvalConfig, mode: Literal["last", "best"]):
output_file_path = get_output_file_path(dfs, config, suffix=mode)
save_files(fig, dfs, output_file_path, config)
def dataframe_from_runs(self, mode: Literal["last", "best"]) -> DataFrame:
dfs: list[DataFrame] = []
for run in self.runs():
df = json_normalize(run.get_config())
if mode == "last":
result_file = run.run_dir / run.evaluation.experiment_files.last_model
elif mode == "best":
result_file = run.run_dir / run.evaluation.experiment_files.best_model
else:
raise ValueError(f"mode {mode} not supported")
if not result_file.is_file():
log_warn(f"result file {result_file} not found, skipping this hyperparameter setting")
continue
metric = run.evaluation.plot.metric
with open(result_file, "r", encoding="utf8") as f:
content = json.load(f)
if metric in content[0]:
df.at[0, metric] = content[0][metric]
else:
log_warn(f"could not find value for {metric} in json, skipping this hyperparameter setting")
continue
dfs.append(df)
if len(dfs) == 0:
raise ValueError("no dataframes found, check your config")
return concat(dfs, sort=False)
def _make_run(self, n: int) -> Run:
"""
n: number of the run, starting from 1
setup: download and prepare data
"""
i = n - 1
return Run(
self._runs[i],
self._defaults[i],
self.task_key,
self.optimizer_key,
self.engine_key,
self.eval_key,
self.identifier_key
)
def _named_dicts_to_list(self, searchspace: dict[str, Any], keys: list[str], valid_options: list[list[str]]):
assert len(keys) == len(valid_options)
for key, opts in zip(keys, valid_options):
if key not in searchspace:
continue
if isinstance(searchspace[key], dict) and all(name in opts for name in searchspace[key]):
searchspace[key] = [cfg | {self.identifier_key: name} for name, cfg in searchspace[key].items()]
def _fill_defaults(self):
self._defaults = []
for run in self._runs:
default_cfg = {
k: {self.identifier_key: run[k][self.identifier_key]}
for k in [self.task_key, self.optimizer_key]
}
self._defaults.append(default_cfg)
self._fill_runs_from_default(self._defaults)
def _fill_runs_from_default(self, runs: list[dict[str, Any]]):
for i, _ in enumerate(runs):
# order from higher to lower in hierarchy
runs[i] = self._fill_named_from_default(runs[i], self.task_key, task_path)
runs[i] = self._fill_named_from_default(runs[i], self.optimizer_key, optimizer_path)
runs[i] = self._fill_unnamed_from_default(runs[i], lr_schedulers_path)
runs[i] = self._fill_unnamed_from_default(runs[i], engine_path)
runs[i] = self._fill_unnamed_from_default(runs[i], evaluation_path)
def _fill_unnamed_from_default(self, experiment: dict[str, Any], unnamed_root: Callable) -> dict[str, Any]:
default_path: Path = unnamed_root() / self.default_file_name
default_config = self.parser.parse_yaml(default_path)
self.parser.merge_dicts_hierarchical(default_config, experiment)
return default_config
def _fill_named_from_default(self, experiment: dict[str, Any], key: str, named_root: Callable) -> dict[str, Any]:
self._argcheck_named(experiment, key, self.identifier_key)
named = experiment[key]
if isinstance(named, dict):
named = named[self.identifier_key]
else:
experiment[key] = {self.identifier_key: named}
default_path: Path = named_root(named) / self.default_file_name
default_config = self.parser.parse_yaml(default_path)
self.parser.merge_dicts_hierarchical(default_config, experiment)
return default_config
def _argcheck_named(self, experiment: dict[str, Any], key: str, identifier: str):
assert key in experiment, f"You did not provide any {key}."
assert isinstance(experiment[key], str) or identifier in experiment[key], \
f"Unknown {key}, either specify only a string or provide a key '{identifier}'"