Convert FOB submodule to regular folder

This commit is contained in:
arihanv 2025-05-18 16:36:28 -07:00
parent 94f046ad40
commit 94825011a0
74 changed files with 4563 additions and 0 deletions

View file

@ -0,0 +1,7 @@
from pathlib import Path
from pytorch_fob.engine.engine import Engine
def repository_root() -> Path:
    """Return the absolute path of the repository root (two levels above this file)."""
    this_file = Path(__file__).resolve()
    return this_file.parent.parent

View file

@ -0,0 +1,272 @@
import math
import time
from typing import Iterable, Optional
import deepspeed
import torch
from lightning import Callback, LightningModule, Trainer
from lightning_utilities.core.rank_zero import rank_zero_only
from torch.linalg import vector_norm
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, seconds_to_str
class RestrictTrainEpochs(Callback):
    """Counts number of epochs since start of training and stops if max_epochs is reached."""

    def __init__(self, max_epochs: int):
        super().__init__()
        self.max_epochs = max_epochs
        self.epochs = 0
        self.skip_first = False

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        log_debug(f"Training for {self.max_epochs} epochs...")
        # restart the counter; a previous checkpoint may have set should_stop
        self.epochs = 0
        trainer.should_stop = False

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if self.skip_first:
            # the epoch that finished was already counted before the checkpoint was written
            self.skip_first = False
            return
        self.epochs += 1
        log_debug(f"Epoch {self.epochs}/{self.max_epochs}")
        # TODO: test for DDP, do we need 'trainer.strategy.reduce_boolean_decision'?
        if self.epochs >= self.max_epochs:
            log_debug(f"Stopping training after {self.epochs} epochs")
            trainer.should_stop = True

    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint):
        # checkpoint loads the model at the end of the epoch, so we do not count the first epoch
        self.skip_first = True
class OptimizerTime(Callback):
def __init__(self):
super().__init__()
self.total_mean_optimizer_step_time_ms: float = 0.0
self.total_epochs: int = 0
def on_train_epoch_end(self, trainer, pl_module):
if len(pl_module.optimizer_times_ms) == 0:
return
epoch_mean = sum(pl_module.optimizer_times_ms) / len(pl_module.optimizer_times_ms)
pl_module.log("mean_optimizer_step_time_ms", epoch_mean, on_step=False, on_epoch=True, sync_dist=True)
# Update the running mean
self.total_epochs += 1
self.total_mean_optimizer_step_time_ms = (
(self.total_mean_optimizer_step_time_ms * (self.total_epochs - 1)) + epoch_mean
) / self.total_epochs
# Reset the optimizer step times for the next epoch
pl_module.optimizer_times_ms = [] # type: ignore
def state_dict(self) -> dict[str, float | int]:
return {"running_mean": self.total_mean_optimizer_step_time_ms, "total_epochs": self.total_epochs}
def load_state_dict(self, state_dict: dict[str, float | int]):
self.total_mean_optimizer_step_time_ms = state_dict["running_mean"]
self.total_epochs = state_dict["total_epochs"] # type: ignore
class PrintEpochWithTime(Callback):
    """Logs, per epoch, the wall-clock time spent in training vs. validation."""

    def __init__(self, active: bool = True):
        super().__init__()
        self.active: bool = active
        self.time: dict[str, Optional[float]]
        self.reset_time()

    def reset_time(self):
        """Clear all recorded timestamps."""
        self.time = {"train_start": None, "val_start": None, "val_end": None}

    @rank_zero_only
    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        if self.active:
            self.time["train_start"] = time.time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        # need to print here since train epoch ends after validation is done
        timestamps_complete = all(v is not None for v in self.time.values())
        if self.active and timestamps_complete:
            max_epochs = pl_module.config.max_epochs
            train_time = math.ceil(time.time() - self.time["train_start"])  # type: ignore
            val_time = math.ceil(self.time["val_end"] - self.time["val_start"])  # type: ignore
            log_info(
                f"Finished training epoch {trainer.current_epoch + 1} of {max_epochs}. Time spent: training: {seconds_to_str(train_time - val_time)}, validation: {seconds_to_str(val_time)}, total: {seconds_to_str(train_time)}."
            )
            self.reset_time()

    @rank_zero_only
    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        if self.active:
            self.time["val_start"] = time.time()

    @rank_zero_only
    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if self.active:
            self.time["val_end"] = time.time()
def metric_fn(metric: str, v: torch.Tensor, override: Optional[float] = None) -> float:
    """Reduce tensor *v* to a single float according to *metric*.

    When *override* is given it is returned directly, skipping the computation.
    Raises ValueError for an unknown metric name.
    """
    if override is not None:
        return override
    # dispatch table of supported reductions
    reductions = {
        "mean": lambda t: t.mean().item(),
        "sum": lambda t: t.sum().item(),
        "abs_mean": lambda t: t.abs().mean().item(),
        "std": lambda t: t.std().item(),
        "abs_std": lambda t: t.abs().std().item(),
        "min": lambda t: t.min().item(),
        "max": lambda t: t.max().item(),
        "l1": lambda t: vector_norm(t, ord=1).item(),
        "l2": lambda t: vector_norm(t, ord=2).item(),
        "sq_mean": lambda t: (t**2).mean().item(),
        "sq_sum": lambda t: (t**2).sum().item(),
    }
    try:
        reduce = reductions[metric]
    except KeyError:
        raise ValueError(f"unknown metric {metric}") from None
    return reduce(v)
def add_metrics_to_stats(
    stats: dict[str, float],
    prefix: str,
    name: str,
    v: torch.Tensor,
    metrics: Iterable[str],
    override: Optional[float] = None,
):
    """Compute every metric in *metrics* for tensor *v* and store it in *stats*.

    Keys have the form "<prefix>/<name>/<metric>". When *override* is given,
    it replaces every computed value (see `metric_fn`).
    """
    stats.update(
        {f"{prefix}/{name}/{metric}": metric_fn(metric, v, override=override) for metric in metrics}
    )
class LogTrainingStats(Callback):
    """Periodically logs statistics of parameters, gradients, optimizer momenta and
    effective learning rates to all attached trainer loggers.

    The logging interval can optionally grow geometrically over training
    ('change_log_interval_every_n_steps' / 'log_interval_factor'), clamped to
    ['min_log_interval', 'max_log_interval'].
    """
    def __init__(
        self,
        log_gradient: bool = True,
        log_params: bool = True,
        log_quantiles: bool = False,
        log_momentum: bool = False,
        log_lrs: bool = True,
        log_every_n_steps: int = 50,
        change_log_interval_every_n_steps: Optional[int] = None,
        log_interval_factor: float = 2.0,
        min_log_interval: int = 1,
        max_log_interval: Optional[int] = None,
        metrics: Iterable[str] = ("mean", "abs_mean", "std", "abs_std", "min", "max", "l1", "l2", "sq_mean"),
    ):
        super().__init__()
        self.log_gradient = log_gradient
        self.log_params = log_params
        self.log_quantiles = log_quantiles
        self.log_momentum = log_momentum
        self.log_lrs = log_lrs
        self.log_every_n_steps = log_every_n_steps
        self.change_log_interval_every_n_steps = change_log_interval_every_n_steps
        self.log_interval_factor = log_interval_factor
        self.min_log_interval = min_log_interval
        self.max_log_interval = max_log_interval
        # metric names understood by `metric_fn`
        self.metrics = metrics

    def _check_and_adjust_log_interval(self, trainer: Trainer, pl_module: LightningModule) -> bool:
        """Grow the logging interval if due; return True when this step should be logged."""
        if self.change_log_interval_every_n_steps is not None:
            if trainer.global_step > 0 and trainer.global_step % self.change_log_interval_every_n_steps == 0:
                # geometric growth, clamped to [min_log_interval, max_log_interval]
                self.log_every_n_steps = math.ceil(self.log_every_n_steps * self.log_interval_factor)
                self.log_every_n_steps = max(self.log_every_n_steps, self.min_log_interval)
                if self.max_log_interval is not None:
                    self.log_every_n_steps = min(self.log_every_n_steps, self.max_log_interval)
            pl_module.log("logging_interval", self.log_every_n_steps)
        return trainer.global_step % self.log_every_n_steps == 0

    @rank_zero_only
    def on_before_optimizer_step(self, trainer: Trainer, pl_module: LightningModule, optimizer: torch.optim.Optimizer):
        if self._check_and_adjust_log_interval(trainer, pl_module):
            stats = {}
            # quantile grid 0.25, 0.5, 0.75 on the model's device
            q = torch.arange(0.25, 1, 0.25).round(decimals=2).to(trainer.model.device)
            for param_group in optimizer.param_groups:
                # assumes each param group carries a parallel "names" list — TODO confirm against GroupedModel.to_optimizer_dict
                for name, param in zip(param_group["names"], param_group["params"]):
                    if self.log_params or self.log_lrs:
                        v_detached = param.detach()
                    if self.log_params:
                        if torch.isnan(v_detached).sum() > 0:
                            log_warn(f"# NaN in param {name}")
                        if torch.isinf(v_detached).sum() > 0:
                            log_warn(f"# Inf in param {name}")
                        add_metrics_to_stats(stats, "param", name, v_detached, self.metrics)
                        # skip quantiles for very large tensors (torch.quantile is expensive)
                        if self.log_quantiles and v_detached.size().numel() < 10000000:
                            deciles = torch.quantile(v_detached.float(), q, interpolation="linear")
                            for q_idx, d_val in enumerate(deciles):
                                stats[f"param/{name}/quantile-{q[q_idx]}"] = d_val.item()
                    if (self.log_gradient or self.log_lrs) and param.requires_grad:
                        if trainer.num_devices > 1:
                            # with sharded/multi-device training the full gradient must be gathered
                            grad_data = deepspeed.utils.safe_get_full_grad(param)
                        else:
                            grad_data = param.grad
                    else:
                        grad_data = None
                    if grad_data is not None:
                        if torch.isnan(grad_data).sum() > 0:
                            log_warn(f"# NaN in grad {name}")
                        if torch.isinf(grad_data).sum() > 0:
                            log_warn(f"# Inf in grad {name}")
                        if self.log_gradient:
                            if torch.isnan(grad_data).sum() > 0 or torch.isinf(grad_data).sum() > 0:
                                # sentinel value -10.0 marks non-finite gradients in the logs
                                add_metrics_to_stats(stats, "grad", name, grad_data, self.metrics, override=-10.0)
                                if self.log_quantiles and grad_data.size().numel() < 10000000:
                                    for q_idx, _ in enumerate(q):
                                        # NOTE(review): key uses "param/" prefix but this is the grad branch — looks like a copy-paste slip, confirm intended key
                                        stats[f"param/{name}/quantile-{q[q_idx]}"] = -10
                                # NOTE(review): mean of a non-finite gradient is itself nan/inf — verify this overwrite is intended
                                stats[f"grad/{name}/mean"] = grad_data.mean().item()
                            # std() needs more than one element, so scalars only get the stats above
                            if len(grad_data.shape) > 1 or grad_data.shape[0] > 1:
                                add_metrics_to_stats(stats, "grad", name, grad_data, self.metrics)
                                if self.log_quantiles and grad_data.size().numel() < 10000000:
                                    deciles = torch.quantile(grad_data.float(), q, interpolation="linear")
                                    for q_idx, d_val in enumerate(deciles):
                                        stats[f"grad/{name}/quantile-{q[q_idx]}"] = d_val.item()
                        if self.log_lrs:
                            # effective lr := |grad| / |param| (0 when the param is all zeros)
                            grad_norm = vector_norm(grad_data)
                            param_norm = vector_norm(v_detached)
                            effective_lr = (grad_norm / param_norm).item() if param_norm != 0 else 0.0
                            stats[f"param/{name}/effective_lr"] = effective_lr
                    if self.log_momentum or self.log_lrs:
                        if param in optimizer.state:
                            state = optimizer.state[param]
                        else:
                            state = {}
                        if self.log_momentum:
                            # Adam-style optimizers store "exp_avg", SGD stores "momentum_buffer"
                            if "exp_avg" in state:
                                moment1 = state["exp_avg"]
                            elif "momentum_buffer" in state:
                                moment1 = state["momentum_buffer"]
                            else:
                                moment1 = None
                            if moment1 is not None:
                                add_metrics_to_stats(stats, "1st_order_momentum", name, moment1, self.metrics)
                            if "exp_avg_sq" in state:
                                add_metrics_to_stats(stats, "2nd_order_momentum", name, state["exp_avg_sq"], self.metrics)
                        if self.log_lrs and "lr" in state:
                            # assumes per-param "lr" state is a tensor — TODO confirm for the optimizers in use
                            stats[f"param/{name}/lr"] = state["lr"].item()
            if trainer.loggers is not None:
                for logger in trainer.loggers:
                    logger.log_metrics(stats, step=trainer.global_step)

View file

@ -0,0 +1,156 @@
from pathlib import Path
from typing import Any, Literal, Optional
from .utils import AttributeDict, EndlessList, convert_type_inside_dict, maybe_abspath, some, wrap_list
class BaseConfig(AttributeDict):
    """Attribute-style config; nested plain dicts are wrapped as AttributeDicts."""

    def __init__(self, config: dict):
        converted = convert_type_inside_dict(config, dict, AttributeDict)
        super().__init__(converted)
class NamedConfig(BaseConfig):
    """BaseConfig that additionally exposes a `name` and an `output_dir_name`."""
    def __init__(
        self,
        config: dict[str, Any],
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        super().__init__(config)
        # the identifier key is required; missing it raises KeyError
        self.name = config[identifier_key]
        # fall back to the name when no explicit output dir name is given
        self.output_dir_name = config.get(outdir_key, self.name)
class OptimizerConfig(NamedConfig):
    """Optimizer section of the config; copies step/epoch limits over from the task section."""
    def __init__(
        self,
        config: dict[str, Any],
        optimizer_key: str,
        task_key: str,
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        cfg = dict(config[optimizer_key])
        # interval at which the lr scheduler advances; defaults to per-step
        self.lr_interval: Literal["step", "epoch"] = cfg.get("lr_interval", "step")
        # None when the task does not limit the number of steps
        self.max_steps: Optional[int] = config[task_key].get("max_steps", None)
        self.max_epochs: int = config[task_key]["max_epochs"]
        # mirror the limits into the optimizer config so they are reachable from there too
        cfg["max_steps"] = self.max_steps
        cfg["max_epochs"] = self.max_epochs
        super().__init__(cfg, identifier_key, outdir_key)
class TaskConfig(NamedConfig):
    """Task section of the config; pulls data_dir/workers from the engine section."""
    def __init__(
        self,
        config: dict[str, Any],
        task_key: str,
        engine_key: str,
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        cfg = dict(config[task_key])
        self.batch_size: int = cfg["batch_size"]
        # shared with the engine section; made absolute here
        self.data_dir = Path(config[engine_key]["data_dir"]).resolve()
        self.max_epochs: int = cfg["max_epochs"]
        # None when the task does not limit the number of steps
        self.max_steps: Optional[int] = cfg.get("max_steps", None)
        self.target_metric: str = cfg["target_metric"]
        self.target_metric_mode: str = cfg["target_metric_mode"]
        self.workers = config[engine_key]["workers"]
        # mirror the engine-owned values into the task config
        cfg["data_dir"] = self.data_dir
        cfg["workers"] = self.workers
        super().__init__(cfg, identifier_key, outdir_key)
class EngineConfig(BaseConfig):
    """Engine section of the config: trainer, run-scheduling and output settings.

    Resolves relative paths to absolute ones, fills defaults for a few optional
    keys, and hands the normalized dict to `BaseConfig`.
    """
    def __init__(self, config: dict[str, Any], task_key: str, engine_key: str) -> None:
        cfg = dict(config[engine_key])
        self.accelerator = cfg["accelerator"]
        self.deterministic: bool | Literal["warn"] = cfg["deterministic"]
        self.data_dir = Path(cfg["data_dir"]).resolve()
        self.detect_anomaly: bool = cfg["detect_anomaly"]
        self.devices: int = some(cfg["devices"], default=1)
        self.early_stopping: Optional[int] = cfg["early_stopping"]
        # falls back to the task's target metric when not set explicitly
        self.early_stopping_metric: str = some(cfg["early_stopping_metric"], default=config[task_key]["target_metric"])
        self.gradient_clip_alg: str = cfg["gradient_clip_alg"]
        self.gradient_clip_val: Optional[float] = cfg["gradient_clip_val"]
        self.log_extra: bool | dict[str, bool] = cfg["log_extra"]
        # 'logging_inteval' (sic) is kept for backward compatibility with existing
        # callers; prefer the correctly spelled 'logging_interval' below.
        self.logging_inteval: int = cfg["logging_interval"]
        self.logging_interval: int = cfg["logging_interval"]
        # None when the task does not limit the number of steps
        self.max_steps: Optional[int] = config[task_key].get("max_steps", None)
        self.optimize_memory: bool = cfg["optimize_memory"]
        self.output_dir = Path(cfg["output_dir"]).resolve()
        self.plot: bool = cfg["plot"]
        self.precision: str = cfg["precision"]
        self.restrict_train_epochs: Optional[int] = cfg["restrict_train_epochs"]
        # 'resume' may be a checkpoint path (str) or a bool flag
        _resume = cfg.get("resume", False)
        self.resume: Optional[Path] | bool = Path(_resume).resolve() if isinstance(_resume, str) else _resume
        self.run_scheduler: str = cfg["run_scheduler"]
        self.seed: int = cfg["seed"]
        self.seed_mode: str = cfg["seed_mode"]
        self.save_sbatch_scripts: Optional[Path] = maybe_abspath(cfg["save_sbatch_scripts"])
        self.sbatch_args: dict[str, str] = cfg["sbatch_args"]
        self.sbatch_script_template: Optional[Path] = maybe_abspath(cfg["sbatch_script_template"])
        self.sbatch_time_factor: float = cfg["sbatch_time_factor"]
        self.slurm_log_dir: Optional[Path] = maybe_abspath(cfg["slurm_log_dir"])
        self.silent: bool = cfg.get("silent", False)
        self.test: bool = cfg.get("test", True)
        self.train: bool = cfg.get("train", True)
        self.validate: bool = cfg.get("validate", False)
        self.workers: int = cfg["workers"]
        # write the normalized values back so the underlying dict agrees with the attributes
        cfg["data_dir"] = self.data_dir
        cfg["devices"] = self.devices
        cfg["early_stopping_metric"] = self.early_stopping_metric
        cfg["max_steps"] = self.max_steps
        cfg["output_dir"] = self.output_dir
        cfg["resume"] = self.resume
        cfg["slurm_log_dir"] = self.slurm_log_dir
        cfg["save_sbatch_scripts"] = self.save_sbatch_scripts
        cfg["sbatch_script_template"] = self.sbatch_script_template
        super().__init__(cfg)

    def outpath_relevant_engine_keys(self, prefix: str = "") -> list[str]:
        """Engine keys that influence results and therefore belong in the output path."""
        keys = [
            "accelerator",
            "deterministic",
            "detect_anomaly",
            "devices",
            "early_stopping",
            "gradient_clip_alg",
            "gradient_clip_val",
            "optimize_memory",
            "precision",
            "seed"
        ]
        return [f"{prefix}{k}" for k in keys]

    def outpath_irrelevant_engine_keys(self, prefix: str = "") -> list[str]:
        """Complement of `outpath_relevant_engine_keys` over all keys of this config."""
        return [f"{prefix}{k}" for k in self.keys() if k not in self.outpath_relevant_engine_keys()]
class EvalConfig(BaseConfig):
    """Evaluation/plotting section of the config.

    Normalizes list-valued options, resolves the output directory (defaulting to
    '<engine.output_dir>/plots') and wraps the plot axis options in EndlessLists.
    """
    def __init__(self, config: dict[str, Any], eval_key: str, engine_key: str, ignore_keys = None) -> None:
        cfg = dict(config[eval_key])
        # canonical file names produced by a finished run
        self.experiment_files = AttributeDict(dict(
            best_model = "results_best_model.json",
            last_model = "results_final_model.json",
            config = "config.yaml"
        ))
        self.output_types: list[str] = wrap_list(cfg["output_types"])
        experiment_dir = Path(config[engine_key]["output_dir"]).resolve()
        self.output_dir: Path = some(maybe_abspath(cfg["output_dir"]), default=experiment_dir / "plots")
        self.experiment_name: str = cfg["experiment_name"]
        self.verbose: bool = cfg.get("verbose", False)
        # bool keeps the flag as-is; anything else is normalized to a list of group names
        split = cfg.get("split_groups", False)
        self.split_groups: bool | list[str] = split if isinstance(split, bool) else wrap_list(split)
        self.checkpoints: list[Literal["last", "best"]] = wrap_list(cfg["checkpoints"])
        self.column_split_key: Optional[str] = cfg.get("column_split_key", None)
        self.column_split_order: Optional[list[str]] = cfg.get("column_split_order", None)
        self.ignore_keys: list[str] = some(ignore_keys, default=[])
        self.aggregate_groups: list[str] = wrap_list(cfg["aggregate_groups"])
        # write the normalized values back so the underlying dict agrees with the attributes
        # (fix: 'output_types' was previously assigned twice)
        cfg["ignore_keys"] = self.ignore_keys
        cfg["output_types"] = self.output_types
        cfg["output_dir"] = self.output_dir
        cfg["aggregate_groups"] = self.aggregate_groups
        cfg["plot"]["x_axis"] = EndlessList(wrap_list(cfg["plot"]["x_axis"]))
        cfg["plot"]["y_axis"] = EndlessList(wrap_list(cfg["plot"]["y_axis"]))
        cfg["split_groups"] = self.split_groups
        super().__init__(cfg)

View file

@ -0,0 +1,41 @@
engine:
accelerator: gpu # Whether to train on cpu or gpu
check_finite: true # Check if 'early_stopping_metric' is finite during training. Aborts training if not. Only active when 'early_stopping' is not null.
data_dir: ./data # Where you want to store the training data
deterministic: warn # 'warn' tries to use deterministic algorithms if possible, also accepts true or false.
detect_anomaly: false # Lightning trainer argument with same name.
devices: null # This is set by each task by default, but can be overridden
early_stopping: null # The number of epochs to wait before stopping if no improvement is found. Set to null to disable.
early_stopping_metric: null # Metric to use for early stopping. If null, uses 'task.target_metric'.
gradient_clip_alg: norm # {value, norm} to disable gradient clipping: set 'gradient_clip_val' to null
gradient_clip_val: null # DEFAULT: don't clip gradients, expects value in [0, 1]
log_extra: false # Activate logging of gradients and more. Can be bool or a dict with the options supported by callback `LogTrainingStats` in `pytorch_fob/engine/callbacks.py`.
logging_interval: 50 # Number of steps between each logging step.
optimize_memory: false # Use nondeterministic, but memory-efficient algorithms for self-attention
output_dir: ./experiments # Where you want to store the results
plot: true # Whether to plot the results.
precision: bf16-mixed # Floating precision of training, see https://lightning.ai/docs/pytorch/stable/common/precision_basic.html
restrict_train_epochs: null # Only train for a specific number of epochs. Set to null to disable. The epochs set here are counted from start of training, so this works with 'resume'.
resume: true # You can either pass the path to your checkpoint here or set to true, which loads the last checkpoint.
run_scheduler: sequential # How to schedule the runs of the experiment. Supported values:
# 'sequential': runs are performed sequentially
# 'single:N' where N is the number of the run starting from 1.
# 'slurm_array': runs are scheduled using a SLURM array job.
# 'slurm_jobs': runs are scheduled using independent SLURM jobs
save_sbatch_scripts: null # Path to directory where sbatch scripts will be saved. If null, sbatch scripts will not be saved.
sbatch_time_factor: 1 # Time factor for SLURM. Multiplies all default times by this factor.
sbatch_args: # Additional arguments to pass to sbatch. Only used if run_scheduler is 'slurm_array'.
# ntasks-per-node and gres are set to 'devices' by default
# cpus-per-task is set to 'workers' by default
nodes: 1
mem-per-cpu: 2gb
  time: 00:30:00 # Each task has its own default time (assumes A100 or similar gpu). Format: HH:MM:SS or seconds.
sbatch_script_template: null # Path to template for the sbatch script. Script can contain placeholder '__FOB_COMMAND__'. Otherwise it will be executed before the experiment. 'sbatch_args' will be added to the beginning of the script.
slurm_log_dir: null # Default: 'output_dir/slurm_logs' for run_scheduler 'slurm_array' and 'run_dir/slurm_logs' for run_scheduler 'slurm_jobs'
seed: 42 # The seed to use for the experiment
seed_mode: fixed # Currently only supports 'fixed'
silent: false # whether to hide progress bars. Recommended when writing outputs to a log file.
test: true # Whether to test the model.
train: true # Whether to train the model.
validate: false # Whether to validate the model after training (only useful if you are interested in the results, for example for HPO).
workers: 16 # The number of processes to use for dataloading

View file

@ -0,0 +1,228 @@
import json
from copy import deepcopy
from typing import Any, Callable, Iterable, Iterator, Literal, Optional
from pathlib import Path
from matplotlib.figure import Figure
from pandas import DataFrame, concat, json_normalize
from pytorch_fob.engine.configs import EvalConfig
from pytorch_fob.engine.grid_search import grid_search
from pytorch_fob.engine.parser import YAMLParser
from pytorch_fob.engine.run import Run
from pytorch_fob.engine.run_schedulers import sequential, slurm_array, slurm_jobs
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, some, sort_dict_recursively
from pytorch_fob.evaluation import evaluation_path
from pytorch_fob.evaluation.plot import create_figure, get_output_file_path, save_files, set_plotstyle
from pytorch_fob.optimizers import lr_schedulers_path, optimizer_path, optimizer_names
from pytorch_fob.tasks import task_path, task_names
def engine_path() -> Path:
    """Return the absolute path of the directory containing this module."""
    this_file = Path(__file__).resolve()
    return this_file.parent
class Engine():
    """Parses an experiment searchspace into concrete runs and orchestrates
    their execution, data preparation and result plotting.

    The searchspace is expanded via grid search; each run is filled with the
    defaults of its task, optimizer, lr scheduler, engine and evaluation.
    """
    def __init__(self) -> None:
        self._runs = []            # list of fully-filled run config dicts
        self._defaults = []        # per-run default configs (for diffing against defaults)
        self._experiment = {}      # the original (normalized) searchspace
        self._experiment_file = None
        self._block_plotting = False  # set when runs are handed off to SLURM (results not local yet)
        self.task_key = "task"
        self.optimizer_key = "optimizer"
        self.engine_key = "engine"
        self.eval_key = "evaluation"
        self.identifier_key = "name"
        self.default_file_name = "default.yaml"
        self.parser = YAMLParser()

    def run_experiment(self) -> Optional[list[int]]:
        """Execute all runs with the configured scheduler.

        Returns SLURM job ids for the 'slurm_jobs' scheduler, otherwise None.
        Requires `parse_experiment` to have been called first.
        """
        assert len(self._runs) > 0, "No runs in experiment, make sure to call 'parse_experiment' first."
        scheduler = self._runs[0][self.engine_key]["run_scheduler"]
        # the scheduler must be identical across runs, it cannot be part of the grid
        assert all(map(lambda x: x[self.engine_key]["run_scheduler"] == scheduler, self._runs)), \
            "You cannot perform gridsearch on 'run_scheduler'."
        if scheduler == "sequential":
            sequential(self.runs(), len(self._runs), self._experiment)
        elif scheduler.startswith("single"):
            # format: 'single:N' with N being the 1-based run number
            n = int(scheduler.rsplit(":", 1)[-1])
            log_info(f"Starting run {n}/{len(self._runs)}.")
            run = self._make_run(n)
            run.start()
        elif scheduler == "slurm_array":
            self._block_plotting = True
            slurm_array(list(self.runs()), self._experiment)
        elif scheduler == "slurm_jobs":
            self._block_plotting = True
            return slurm_jobs(list(self.runs()), self._experiment)
        else:
            raise ValueError(f"Unsupported run_scheduler: {scheduler=}.")

    def parse_experiment_from_file(self, file: Path, extra_args: Iterable[str] = tuple()):
        """Load a searchspace from a YAML file and parse it (see `parse_experiment`)."""
        self._experiment_file = file.resolve()
        searchspace: dict[str, Any] = self.parser.parse_yaml(self._experiment_file)
        self.parse_experiment(searchspace, extra_args)

    def parse_experiment(self, searchspace: dict[str, Any], extra_args: Iterable[str] = tuple()):
        """Expand *searchspace* into runs and fill each with defaults.

        extra_args are 'key=value' overrides merged into the searchspace.
        """
        self.parser.parse_args_into_searchspace(searchspace, extra_args)
        # normalize experiment
        self._named_dicts_to_list(
            searchspace,
            [self.optimizer_key, self.task_key],
            [optimizer_names(), task_names()]
        )
        searchspace = sort_dict_recursively(searchspace)
        self._experiment = deepcopy(searchspace)
        # exclude plotting from gridsearch
        if self.eval_key in searchspace:
            eval_config = searchspace.pop(self.eval_key)
        else:
            eval_config = {}
        log_debug("Performing gridsearch...")
        self._runs = grid_search(searchspace)
        log_debug(f"Found {len(self._runs)} runs.")
        for run in self._runs:
            run[self.eval_key] = eval_config
        self._fill_runs_from_default(self._runs)
        self._fill_defaults()

    def runs(self) -> Iterator[Run]:
        """
        Creates and initializes runs from parsed run config.
        """
        for n, _ in enumerate(self._runs, start=1):
            yield self._make_run(n)

    def prepare_data(self):
        """Download/prepare the datamodule data once per distinct task."""
        prepared = set()
        for n, t in enumerate(self._runs, start=1):
            name = t["task"]["name"]
            if name not in prepared:
                run = self._make_run(n)
                log_info(f"Setting up data for {run.task_key} '{run.task.name}'...")
                run.get_datamodule().prepare_data()
                log_info("... finished.")
                prepared.add(name)

    def plot(self, save: bool = True) -> list[Figure]:
        """Create (and optionally save) one figure per configured checkpoint mode.

        Returns an empty list when plotting is disabled or blocked (SLURM schedulers).
        """
        run = next(self.runs())
        if self._block_plotting or not run.engine.plot:
            return []
        config = run.evaluation
        set_plotstyle(config)
        figs = []
        for mode in config.checkpoints:
            df = self.dataframe_from_runs(mode)
            if config.plot.single_file:
                fig, dfs = self.plot_one_fig(df, config)
                if save:
                    self.save_one_plot(fig, dfs, config, mode)
                figs.append(fig)
            else:
                # TODO: option to split into multiple files
                raise NotImplementedError("evaluation.plot.single_file=False is not implemented yet.")
        return figs

    def plot_one_fig(self, df: DataFrame, config: EvalConfig):
        """Build one figure; optionally split the data into columns by `column_split_key`."""
        if config.column_split_key is None:
            dfs = [df]
        else:
            groups = df.groupby(config.column_split_key)
            # default order: sorted group names
            order = some(config.column_split_order, default=map(lambda x: x[0], sorted(groups)))
            dfs: list[DataFrame] = [groups.get_group(group_name) for group_name in order]
        fig, _ = create_figure(dfs, config)
        return fig, dfs

    def save_one_plot(self, fig, dfs: list[DataFrame], config: EvalConfig, mode: Literal["last", "best"]):
        """Write the figure (and related outputs) to the configured output path."""
        output_file_path = get_output_file_path(dfs, config, suffix=mode)
        save_files(fig, dfs, output_file_path, config)

    def dataframe_from_runs(self, mode: Literal["last", "best"]) -> DataFrame:
        """Collect each run's config plus its result metric into one DataFrame.

        Runs with missing result files or missing metrics are skipped with a warning.
        Raises ValueError when no run produced usable results.
        """
        dfs: list[DataFrame] = []
        for run in self.runs():
            df = json_normalize(run.get_config())
            if mode == "last":
                result_file = run.run_dir / run.evaluation.experiment_files.last_model
            elif mode == "best":
                result_file = run.run_dir / run.evaluation.experiment_files.best_model
            else:
                raise ValueError(f"mode {mode} not supported")
            if not result_file.is_file():
                log_warn(f"result file {result_file} not found, skipping this hyperparameter setting")
                continue
            metric = run.evaluation.plot.metric
            with open(result_file, "r", encoding="utf8") as f:
                content = json.load(f)
            # result files contain a list of dicts; the metric is read from the first entry
            if metric in content[0]:
                df.at[0, metric] = content[0][metric]
            else:
                log_warn(f"could not find value for {metric} in json, skipping this hyperparameter setting")
                continue
            dfs.append(df)
        if len(dfs) == 0:
            raise ValueError("no dataframes found, check your config")
        return concat(dfs, sort=False)

    def _make_run(self, n: int) -> Run:
        """
        n: number of the run, starting from 1
        setup: download and prepare data
        """
        i = n - 1
        return Run(
            self._runs[i],
            self._defaults[i],
            self.task_key,
            self.optimizer_key,
            self.engine_key,
            self.eval_key,
            self.identifier_key
        )

    def _named_dicts_to_list(self, searchspace: dict[str, Any], keys: list[str], valid_options: list[list[str]]):
        """Turn '{key: {name1: cfg1, ...}}' into '{key: [cfg1 | {"name": name1}, ...]}'."""
        assert len(keys) == len(valid_options)
        for key, opts in zip(keys, valid_options):
            if key not in searchspace:
                continue
            if isinstance(searchspace[key], dict) and all(name in opts for name in searchspace[key]):
                searchspace[key] = [cfg | {self.identifier_key: name} for name, cfg in searchspace[key].items()]

    def _fill_defaults(self):
        """Build, for every run, the pure default config of its task/optimizer pair."""
        self._defaults = []
        for run in self._runs:
            default_cfg = {
                k: {self.identifier_key: run[k][self.identifier_key]}
                for k in [self.task_key, self.optimizer_key]
            }
            self._defaults.append(default_cfg)
        self._fill_runs_from_default(self._defaults)

    def _fill_runs_from_default(self, runs: list[dict[str, Any]]):
        """Merge each run config on top of the default configs, in hierarchy order."""
        for i, _ in enumerate(runs):
            # order from higher to lower in hierarchy
            runs[i] = self._fill_named_from_default(runs[i], self.task_key, task_path)
            runs[i] = self._fill_named_from_default(runs[i], self.optimizer_key, optimizer_path)
            runs[i] = self._fill_unnamed_from_default(runs[i], lr_schedulers_path)
            runs[i] = self._fill_unnamed_from_default(runs[i], engine_path)
            runs[i] = self._fill_unnamed_from_default(runs[i], evaluation_path)

    def _fill_unnamed_from_default(self, experiment: dict[str, Any], unnamed_root: Callable) -> dict[str, Any]:
        """Merge *experiment* on top of '<unnamed_root()>/default.yaml'."""
        default_path: Path = unnamed_root() / self.default_file_name
        default_config = self.parser.parse_yaml(default_path)
        self.parser.merge_dicts_hierarchical(default_config, experiment)
        return default_config

    def _fill_named_from_default(self, experiment: dict[str, Any], key: str, named_root: Callable) -> dict[str, Any]:
        """Merge *experiment* on top of the named default '<named_root(name)>/default.yaml'."""
        self._argcheck_named(experiment, key, self.identifier_key)
        named = experiment[key]
        if isinstance(named, dict):
            named = named[self.identifier_key]
        else:
            # shorthand: a plain string is expanded to {"name": <string>}
            experiment[key] = {self.identifier_key: named}
        default_path: Path = named_root(named) / self.default_file_name
        default_config = self.parser.parse_yaml(default_path)
        self.parser.merge_dicts_hierarchical(default_config, experiment)
        return default_config

    def _argcheck_named(self, experiment: dict[str, Any], key: str, identifier: str):
        """Validate that a named section is either a string or a dict with the identifier key."""
        assert key in experiment, f"You did not provide any {key}."
        assert isinstance(experiment[key], str) or identifier in experiment[key], \
            f"Unknown {key}, either specify only a string or provide a key '{identifier}'"

View file

@ -0,0 +1,32 @@
from typing import Any
def unique(xs: list) -> list:
    """Return *xs* without duplicates, keeping the first occurrence of each element.

    Membership is tested with `==` (no hashing), so unhashable elements such
    as dicts are supported.
    """
    deduped: list = []
    for item in xs:
        if item not in deduped:
            deduped.append(item)
    return deduped
def grid_search(d: dict[str, Any]) -> list[dict[str, Any]]:
    """Expand a searchspace into the list of all concrete configurations.

    Dicts are expanded key-by-key (cartesian product over their values),
    lists denote alternative values, and any other value is a single option.
    """
    if isinstance(d, dict):
        if len(d) == 0:
            return [dict()]
        remaining = d.copy()
        # split off the last-inserted key, recurse on both halves
        key, subspace = remaining.popitem()
        options = unique(grid_search(subspace))
        partials = grid_search(remaining)
        return [partial | {key: option} for partial in partials for option in options]
    if isinstance(d, list):
        # a list is a union of alternatives: expand each and concatenate
        return [cfg for alternative in d for cfg in grid_search(alternative)]
    # scalar leaf: exactly one option
    return [d]

View file

@ -0,0 +1,226 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Optional
from torch import nn
from torch.nn import Module
from torch.nn.parameter import Parameter
from pytorch_fob.engine.utils import some, log_warn
@dataclass
class ParameterGroup():
    """A named set of parameters together with per-group optimizer settings."""
    named_parameters: dict[str, Parameter]
    lr_multiplier: Optional[float] = field(default=None)
    weight_decay_multiplier: Optional[float] = field(default=None)
    optimizer_kwargs: dict[str, Any] = field(default_factory=dict)

    def __and__(self, other) -> "ParameterGroup":
        """Intersect two groups; the right-hand side's settings take precedence."""
        assert isinstance(other, ParameterGroup)
        shared_names = set(self.named_parameters.keys()) & set(other.named_parameters.keys())
        merged = self.named_parameters | other.named_parameters
        return ParameterGroup(
            named_parameters={name: merged[name] for name in shared_names},
            lr_multiplier=some(other.lr_multiplier, default=self.lr_multiplier),
            weight_decay_multiplier=some(other.weight_decay_multiplier, default=self.weight_decay_multiplier),
            optimizer_kwargs=self.optimizer_kwargs | other.optimizer_kwargs
        )

    def __len__(self) -> int:
        return len(self.named_parameters)

    def __bool__(self) -> bool:
        return not self.empty()

    def empty(self) -> bool:
        """True when the group contains no parameters."""
        return len(self.named_parameters) == 0

    def to_optimizer_dict(
        self,
        lr: Optional[float] = None,
        weight_decay: Optional[float] = None
    ) -> dict[str, list[Parameter] | Any]:
        """Convert to a torch param-group dict, applying the group's multipliers.

        Parameters are listed in sorted-name order; an extra "names" entry keeps
        the parallel list of parameter names.
        """
        ordered_names = sorted(self.named_parameters)
        group_dict = {
            "params": [self.named_parameters[n] for n in ordered_names],
            "names": ordered_names,
            **self.optimizer_kwargs
        }
        if lr is not None:
            group_dict["lr"] = lr if self.lr_multiplier is None else self.lr_multiplier * lr
        if weight_decay is not None:
            if self.weight_decay_multiplier is None:
                group_dict["weight_decay"] = weight_decay
            else:
                group_dict["weight_decay"] = self.weight_decay_multiplier * weight_decay
        return group_dict
class GroupedModel(Module):
    """
    Wrapper around a nn.Module that lets tasks assign different optimizer
    settings to different parameters. Subclass this and override
    `parameter_groups`, then wrap your model with it before handing it to the
    `TaskModel` superclass `__init__`.
    """

    def __init__(self, model: Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        # delegate straight to the wrapped module
        return self.model.forward(*args, **kwargs)

    def parameter_groups(self) -> list[ParameterGroup]:
        # default grouping: weight-decay-based split of the wrapped model
        return wd_group_named_parameters(self.model)

    def grouped_parameters(
        self,
        lr: Optional[float] = None,
        weight_decay: Optional[float] = None
    ) -> list[dict[str, list[Parameter] | Any]]:
        groups = self.parameter_groups()
        return [group.to_optimizer_dict(lr, weight_decay) for group in groups]
def merge_parameter_splits(split1: list[ParameterGroup], split2: list[ParameterGroup]) -> list[ParameterGroup]:
    """
    Merge two lists of ParameterGroup objects into a single list.
    Assumes that both input lists partition the parameters.
    """
    merged = []
    for group_a in split1:
        for group_b in split2:
            combined = group_a & group_b
            # empty intersections are dropped (ParameterGroup.__bool__)
            if combined:
                merged.append(combined)
    return merged
def group_named_parameters(
        model: Module,
        g1_conds: Iterable[Callable] = (lambda *_: True,),
        g2_conds: Iterable[Callable] = (lambda *_: True,),
        special_conds: Iterable[Callable] = tuple(),
        ignore_conds: Iterable[Callable] = tuple(),
        g1_kwargs: Optional[dict[str, Any]] = None,
        g2_kwargs: Optional[dict[str, Any]] = None,
        debug: bool = False
) -> list[ParameterGroup]:
    """
    Group named parameters based on specified conditions and return a list of ParameterGroup objects.
    Each condition is called as ``cond(module, parameter, full_param_name)``.
    Args:
        model (Module): The neural network model.
        g1_conds (Iterable[Callable]): Conditions for selecting parameters for group 1.
        g2_conds (Iterable[Callable]): Conditions for selecting parameters for group 2.
        special_conds (Iterable[Callable]): Conditions for selecting special parameters that should not be grouped.
        ignore_conds (Iterable[Callable]): Conditions for ignoring parameters (e.g. if they occur in submodules).
        g1_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for constructor of group 1.
        g2_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for constructor of group 2.
        debug (bool): If True, warn about parameters that matched no rule.
    Returns:
        List[ParameterGroup]: A list of ParameterGroup objects containing named parameters.
    """
    g1_kwargs = g1_kwargs if g1_kwargs is not None else {}
    g2_kwargs = g2_kwargs if g2_kwargs is not None else {}
    s1 = set()
    s2 = set()
    special = set()
    # only trainable parameters are considered
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn  # full param name
            if not p.requires_grad or fpn not in param_dict:
                continue  # frozen weights
            elif any(c(m, p, fpn) for c in ignore_conds):
                continue
            elif any(c(m, p, fpn) for c in special_conds):
                special.add(fpn)
            elif any(c(m, p, fpn) for c in g1_conds):
                s1.add(fpn)
            elif any(c(m, p, fpn) for c in g2_conds):
                s2.add(fpn)
            elif debug:
                # fix: previously extra positional args were passed to log_warn, which
                # forwards them as %-style logging args (no placeholders -> formatting error)
                log_warn(f"group_named_parameters: Not using any rule for {fpn} in {type(m)}")
    # anything not matched by the rules above defaults to group 1
    s1 |= (param_dict.keys() - s2 - special)
    # validate that we considered every parameter
    inter_params = s1 & s2
    union_params = s1 | s2
    unassigned = param_dict.keys() - special - union_params
    assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both s1/s2 sets!"
    # fix: the message now reports the same set the condition checks (special excluded)
    assert len(unassigned) == 0, f"parameters {str(unassigned)} were not separated into either s1/s2 set!"
    if not s2:
        # no second group: return a single group with default settings
        param_groups = [ParameterGroup(
            named_parameters={pn: param_dict[pn] for pn in sorted(union_params)}
        )]
    else:
        param_groups = [
            ParameterGroup(
                named_parameters={pn: param_dict[pn] for pn in sorted(s1)},
                **g1_kwargs
            ),
            ParameterGroup(
                named_parameters={pn: param_dict[pn] for pn in sorted(s2)},
                **g2_kwargs
            ),
        ]
    return param_groups
def wd_group_named_parameters(model: Module) -> list[ParameterGroup]:
    """
    Split ``model``'s trainable parameters into a weight-decay group
    (weights of linear/conv layers) and a no-weight-decay group (biases,
    normalization layers, embeddings). Parameters carrying a ``_optim``
    attribute are kept separate; containers like ``nn.Sequential`` are skipped
    so each parameter is classified at its owning module.
    """
    decay_types = (nn.Linear, nn.modules.conv._ConvNd)  # pylint: disable=protected-access # noqa
    no_decay_types = (nn.modules.batchnorm._NormBase,  # pylint: disable=protected-access # noqa
                      nn.GroupNorm, nn.LayerNorm,
                      nn.LocalResponseNorm,
                      nn.Embedding)
    container_types = (nn.Sequential,)
    return group_named_parameters(
        model,
        g1_conds=[lambda m, _, pn: pn.endswith('weight') and isinstance(m, decay_types)],
        g2_conds=[lambda m, _, pn: pn.endswith('bias') or isinstance(m, no_decay_types)],
        special_conds=[lambda m, p, pn: hasattr(p, '_optim')],
        ignore_conds=[lambda m, p, pn: isinstance(m, container_types)],
        g2_kwargs={'weight_decay_multiplier': 0.0}
    )
def resolve_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Split two optimizer param-group dicts into three disjoint dicts:
    ``[only-in-dict1, only-in-dict2, in-both]``. Extra kwargs are carried
    over from their source dict; for shared parameters, dict2's kwargs (and
    params) take precedence on conflicts. Names are sorted in each output.
    """
    names1, names2 = set(dict1["names"]), set(dict2["names"])
    by_name1 = dict(zip(dict1["names"], dict1["params"]))
    by_name2 = dict(zip(dict2["names"], dict2["params"]))
    assert len(names1) == len(dict1["params"])
    assert len(names2) == len(dict2["params"])
    extra1 = {k: v for k, v in dict1.items() if k not in ("params", "names")}
    extra2 = {k: v for k, v in dict2.items() if k not in ("params", "names")}
    shared = names1 & names2
    only1 = names1 - names2
    only2 = names2 - names1
    assert shared | only1 | only2 == names1 | names2

    def _group(names: set, lookup: dict[str, Any], extras: dict[str, Any]) -> dict[str, Any]:
        # one output dict: params ordered by sorted name, plus carried kwargs
        ordered = sorted(names)
        return {"params": [lookup[n] for n in ordered], "names": ordered, **extras}

    return [
        _group(only1, by_name1, extra1),
        _group(only2, by_name2, extra2),
        # extras of dict2 take precedence if a kwarg is present in both dicts
        _group(shared, by_name1 | by_name2, {**extra1, **extra2}),
    ]
def intersect_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> Optional[dict[str, Any]]:
    """Return the merged dict for parameters present in both inputs, or None if there are none."""
    merged = resolve_parameter_dicts(dict1, dict2)[2]
    return merged if merged["params"] else None
def merge_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> list[dict[str, Any]]:
    """Split two param dicts into disjoint parts, dropping any empty result."""
    return [part for part in resolve_parameter_dicts(dict1, dict2) if len(part["params"]) > 0]

View file

@ -0,0 +1,65 @@
from pathlib import Path
from typing import Any, Iterable, Optional
import re
import yaml
class YAMLParser():
    """Loads YAML config files and layers overrides (files and CLI-style args) on top."""

    def __init__(self) -> None:
        pass

    def parse_yaml(self, file: Path) -> Any:
        """
        Opens and parses a YAML file.
        """
        with open(file, "r", encoding="utf8") as handle:
            return yaml.safe_load(handle)

    def parse_yamls_and_extra_args(self,
                                   default_yaml: Path,
                                   custom_yaml: Optional[Path],
                                   additional_args: Iterable[str] = tuple()
                                   ) -> dict:
        """
        Load the default config, merge the custom file over it, then apply
        extra 'a.b=c'-style overrides. Assumes each YAML holds a mapping.
        """
        merged = self.parse_yaml(default_yaml)
        if custom_yaml is not None:
            overrides = self.parse_yaml(custom_yaml)
            # merge in place
            self.merge_dicts_hierarchical(lo=merged, hi=overrides)
        self.parse_args_into_searchspace(merged, additional_args)
        return merged

    def parse_args_into_searchspace(self, searchspace: dict[str, Any], args: Iterable[str]):
        """
        Overwrites args given in the form of 'this.that=something'. Also supports lists: 'this.that[0]=something'
        """
        for arg in args:
            self._parse_arg_into_searchspace(searchspace, arg)

    def _parse_arg_into_searchspace(self, searchspace: dict[str, Any], arg: str):
        # 'a.b[2].c=1' -> key path ['a', 'b', 2, 'c'] and YAML-parsed value '1'
        keys, value = arg.split("=")
        path: list[Any] = []
        for key in keys.split("."):
            indexed = re.search(r"^(.*?)\[(\-?\d+)\]$", key)
            if indexed:
                path.append(indexed.group(1))
                path.append(int(indexed.group(2)))
            else:
                path.append(key)
        target = searchspace
        for key in path[:-1]:
            # create intermediate dicts for unknown keys along the way
            if isinstance(target, dict) and key not in target:
                target[key] = {}
            target = target[key]
        target[path[-1]] = yaml.safe_load(value)

    def merge_dicts_hierarchical(self, lo: dict, hi: dict):
        """
        Recursively overwrite values in `lo` with values from `hi` where keys collide.
        """
        for key, value in hi.items():
            if isinstance(value, dict) and isinstance(lo.get(key, None), dict):
                self.merge_dicts_hierarchical(lo[key], value)
            else:
                lo[key] = value

View file

@ -0,0 +1,298 @@
import hashlib
from pathlib import Path
import time
from typing import Any, Optional
from lightning import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import Logger, TensorBoardLogger, CSVLogger
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
import torch
import yaml
from pytorch_fob.engine.callbacks import LogTrainingStats, OptimizerTime, PrintEpochWithTime, RestrictTrainEpochs
from pytorch_fob.engine.configs import EngineConfig, EvalConfig, OptimizerConfig, TaskConfig
from pytorch_fob.engine.utils import AttributeDict, EndlessList, calculate_steps, concatenate_dict_keys, convert_type_inside_dict, dict_differences, findfirst, path_to_str_inside_dict, precision_with_fallback, seconds_to_str, trainer_strategy, write_results, log_warn, log_info
from pytorch_fob.optimizers.optimizers import Optimizer
from pytorch_fob.tasks.tasks import TaskDataModule, TaskModel, import_task
class Run():
    """
    One concrete training/evaluation run: a single combination of task,
    optimizer, engine and evaluation settings taken from the experiment config.

    Builds the Lightning Trainer, callbacks and loggers, executes
    train/validate/test as configured, and writes config, checkpoints and
    scores below ``self.run_dir``.
    """

    def __init__(
            self,
            config: dict[str, Any],
            default_config: dict[str, Any],
            task_key: str,
            optimizer_key: str,
            engine_key: str,
            eval_key: str,
            identifier_key: str
    ) -> None:
        """
        setup: download and prepare data before creating the Run.

        ``config`` holds the resolved values for this run; ``default_config``
        is kept so the output path can encode only the non-default settings.
        The ``*_key`` arguments name the sections inside both dicts.
        """
        self._config = config
        self._default_config = default_config
        self.task_key = task_key
        self.optimizer_key = optimizer_key
        self.engine_key = engine_key
        self.eval_key = eval_key
        self.identifier_key = identifier_key
        self._generate_configs()
        self._set_outpath()
        self._callbacks = AttributeDict({})

    def start(self) -> dict[str, _EVALUATE_OUTPUT]:
        """Execute this run (train and/or test as configured) and return all scores."""
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.export_config()
        scores: dict[str, _EVALUATE_OUTPUT] = {}
        if any([self.engine.train, self.engine.test]):
            self._ensure_resume_path()
            self.ensure_max_steps()
            torch.set_float32_matmul_precision('high')
            seed_everything(self.engine.seed, workers=True)
            model, data_module = self.get_task()
            if self.engine.train:
                trainer = self.get_trainer()
                self._train(trainer, model, data_module)
                scores["mean_optimizer_time_ms"] = self._callbacks["optimizer_time"].total_mean_optimizer_step_time_ms
                if self.engine.validate:
                    scores["validation"] = self._validate(trainer, model, data_module)
            if self.engine.test:
                tester = self.get_tester()
                if self.engine.train:  # no need to load last checkpoint, model is already loaded
                    ckpt = None
                elif self.engine.resume is not None:
                    ckpt = self.engine.resume
                else:
                    log_warn(
                        "No last checkpoint found, evaluating untrained model. " + \
                        "If this is unexpected, try to set 'engine.resume=true'."
                    )
                    ckpt = None
                scores["test_final"] = self._test(tester, model, data_module, ckpt=ckpt)  # type: ignore (see ensure_resume_path)
                best_path = self.get_best_checkpoint()
                if best_path is not None:
                    scores["test_best"] = self._test(tester, model, data_module, Path(best_path))
                else:
                    log_info("No best checkpoint found, skipping test.")
        write_results(scores, self.run_dir / "scores.json")
        return scores

    def _train(self, trainer: Trainer, model: LightningModule, data_module: LightningDataModule):
        """Fit the model, timing the whole training phase and writing train_time.txt."""
        start_time = time.time()
        if self.engine.accelerator == "gpu" and torch.cuda.is_available():
            # restrict the scaled-dot-product-attention backends; mem-efficient
            # kernels are only enabled when memory optimization is requested or
            # determinism is not required
            with torch.backends.cuda.sdp_kernel(
                enable_flash=True,
                enable_math=True,
                enable_mem_efficient=(self.engine.optimize_memory or not self.engine.deterministic)
            ):
                trainer.fit(model, datamodule=data_module, ckpt_path=self.engine.resume)  # type: ignore
        else:
            trainer.fit(model, datamodule=data_module, ckpt_path=self.engine.resume)  # type: ignore
        end_time = time.time()
        train_time = int(end_time - start_time)
        log_info(f"Finished training in {seconds_to_str(train_time)}.")
        # Write train_time.txt
        train_time_path = self.run_dir / "train_time.txt"
        with open(train_time_path, "w") as f:
            f.write(str(train_time) + "\n")

    def _validate(self, trainer: Trainer, model: LightningModule, data_module: LightningDataModule) -> _EVALUATE_OUTPUT:
        """Run the validation loop with the (already fitted) trainer."""
        score = trainer.validate(model, datamodule=data_module)
        return score

    def _test(self, tester: Trainer, model: LightningModule, data_module: LightningDataModule, ckpt: Optional[Path] = None) -> _EVALUATE_OUTPUT:
        """Run the test loop from ``ckpt`` (or ``engine.resume``) and write results_<mode>_model.json."""
        ckpt_path = self.engine.resume if ckpt is None else ckpt
        # "final" for the last/untrained model, "best" for a best-* checkpoint
        mode = "final" if ckpt_path is None or ckpt_path.stem.startswith("last") else "best"  # type: ignore
        log_info(f"Testing {mode} checkpoint...")
        score = tester.test(model, datamodule=data_module, ckpt_path=ckpt_path)  # type: ignore
        write_results(score, self.run_dir / f"results_{mode}_model.json")
        return score

    def export_config(self):
        """Dump this run's fully resolved config to <run_dir>/config.yaml."""
        with open(self.run_dir / "config.yaml", "w", encoding="utf8") as f:
            d = path_to_str_inside_dict(self._config)
            d = convert_type_inside_dict(d, EndlessList, list)
            yaml.safe_dump(d, f)

    def export_config_dict(self) -> dict[str, Any]:
        """Return the config as a plain, YAML-serializable dict (Paths -> str, EndlessList -> list)."""
        d = path_to_str_inside_dict(self._config)
        d = convert_type_inside_dict(d, EndlessList, list)
        return d

    def get_config(self) -> AttributeDict:
        """Attribute-style view of the raw config dict."""
        return AttributeDict(self._config)

    def get_optimizer(self) -> Optimizer:
        """Construct the optimizer wrapper from this run's optimizer config."""
        return Optimizer(self.optimizer)

    def get_task(self) -> tuple[TaskModel, TaskDataModule]:
        """Import the task package and build its model and datamodule."""
        task_module = import_task(self.task.name)
        return task_module.get_task(self.get_optimizer(), self.task)

    def get_datamodule(self) -> TaskDataModule:
        """Build only the task's datamodule (used e.g. for step calculation)."""
        task_module = import_task(self.task.name)
        return task_module.get_datamodule(self.task)

    def get_callbacks(self) -> list[Callback]:
        """Lazily create the callbacks on first access and return them as a list."""
        if len(self._callbacks) < 1:
            self._init_callbacks()
        return list(self._callbacks.values())

    def get_loggers(self) -> list[Logger]:
        """TensorBoard and CSV loggers, both writing below run_dir."""
        return [
            TensorBoardLogger(
                save_dir=self.run_dir,
                name="tb_logs"
            ),
            CSVLogger(
                save_dir=self.run_dir,
                name="csv_logs"
            )
        ]

    def get_trainer(self) -> Trainer:
        """Trainer for the training phase, configured entirely from the engine config."""
        # NOTE(review): the config key is spelled "logging_inteval" (sic) throughout
        # the engine config — keep in sync with EngineConfig if ever renamed.
        return Trainer(
            max_steps=self.engine.max_steps,
            logger=self.get_loggers(),
            callbacks=self.get_callbacks(),
            devices=self.engine.devices,
            strategy=trainer_strategy(self.engine.devices),
            enable_progress_bar=(not self.engine.silent),
            deterministic=self.engine.deterministic,
            detect_anomaly=self.engine.detect_anomaly,
            gradient_clip_val=self.engine.gradient_clip_val,
            gradient_clip_algorithm=self.engine.gradient_clip_alg,
            precision=precision_with_fallback(self.engine.precision),  # type: ignore
            accelerator=self.engine.accelerator,
            log_every_n_steps=self.engine.logging_inteval
        )

    def get_tester(self) -> Trainer:
        """Single-device Trainer used for testing (no loggers, no DDP)."""
        return Trainer(
            devices=1,
            logger=False,
            enable_progress_bar=(not self.engine.silent),
            deterministic=self.engine.deterministic,
            precision=precision_with_fallback(self.engine.precision),  # type: ignore
            accelerator=self.engine.accelerator
        )

    def get_best_checkpoint(self) -> Optional[Path]:
        """
        Path to the best checkpoint, preferring the ModelCheckpoint callback's
        record and falling back to a 'best*' file in the checkpoint dir; None
        if neither exists.
        """
        model_checkpoint = self._callbacks.get("best_model_checkpoint", None)
        if model_checkpoint is not None:
            model_checkpoint = Path(model_checkpoint.best_model_path)
            # an empty best_model_path resolves to a directory -> treat as missing
            model_checkpoint = model_checkpoint if not model_checkpoint.is_dir() else None
        if model_checkpoint is None:
            available_checkpoints = self.get_available_checkpoints()
            model_checkpoint = findfirst(lambda x: x.stem.startswith("best"), available_checkpoints)
        return model_checkpoint

    def get_available_checkpoints(self) -> list[Path]:
        """All *.ckpt files in the checkpoint dir (empty list if the dir does not exist)."""
        if self.checkpoint_dir.exists():
            return list(filter(lambda x: x.suffix == ".ckpt", self.checkpoint_dir.iterdir()))
        return []

    def ensure_max_steps(self):
        """
        Ensures that `self.task.max_steps` is calculated and set correctly.
        """
        if self.task.max_steps is None:
            max_steps = self._calc_max_steps()
            self._config[self.task_key]["max_steps"] = max_steps
            # also patch the default config so the value does not show up as a diff in the output path
            if self._default_config[self.task_key]["max_steps"] is None:
                self._default_config[self.task_key]["max_steps"] = max_steps
            self._generate_configs()
            log_info(f"'max_steps' not set explicitly, using {max_steps=} (calculated from " +
                     f"max_epochs={self.task.max_epochs}, batch_size={self.task.batch_size}, devices={self.engine.devices})")

    def _ensure_resume_path(self):
        """
        Ensures that `self.engine.resume` is either a valid Path or None.

        A boolean True is resolved to the 'last' checkpoint if one exists.
        """
        if isinstance(self.engine.resume, Path):
            pass
        elif isinstance(self.engine.resume, bool):
            resume_path = None
            if self.engine.resume:
                available_checkpoints = self.get_available_checkpoints()
                if len(available_checkpoints) < 1:
                    log_warn("engine.resume=True but no checkpoint was found. Starting run from scratch.")
                else:
                    resume_path = findfirst(lambda x: x.stem == "last", available_checkpoints)
            self._config[self.engine_key]["resume"] = resume_path
            self._generate_configs()
        else:
            raise TypeError(f"Unsupportet type for 'resume', got {type(self.engine.resume)=}.")

    def _calc_max_steps(self) -> int:
        """Derive max_steps from the training-set size and the batch/device config."""
        dm = self.get_datamodule()
        dm.setup("fit")
        train_samples = len(dm.data_train)
        return calculate_steps(self.task.max_epochs, train_samples, self.engine.devices, self.task.batch_size)

    def _init_callbacks(self):
        """Populate self._callbacks with all callbacks configured for this run."""
        self._callbacks["optimizer_time"] = OptimizerTime()
        # tracks the checkpoint with the best target metric
        self._callbacks["best_model_checkpoint"] = ModelCheckpoint(
            dirpath=self.checkpoint_dir,
            filename="best-{epoch}-{step}",
            monitor=self.task.target_metric,
            mode=self.task.target_metric_mode
        )
        # rolling 'last' checkpoint, saved every epoch (used for resuming)
        self._callbacks["model_checkpoint"] = ModelCheckpoint(
            dirpath=self.checkpoint_dir,
            enable_version_counter=False,
            every_n_epochs=1,
            save_last=True
        )
        if self.engine.early_stopping is not None:
            self._callbacks["early_stopping"] = EarlyStopping(
                monitor=self.engine.early_stopping_metric,
                mode=self.task.target_metric_mode,
                patience=self.engine.early_stopping,
                check_finite=self.engine.check_finite,
                log_rank_zero_only=True
            )
        self._callbacks["lr_monitor"] = LearningRateMonitor(
            logging_interval=self.optimizer.lr_interval
        )
        if self.engine.log_extra:
            # log_extra may be a bool or a dict of kwargs for LogTrainingStats
            self._callbacks["extra"] = LogTrainingStats(
                log_every_n_steps=self.engine.logging_inteval,
                **(self.engine.log_extra if isinstance(self.engine.log_extra, dict) else {})
            )
        self._callbacks["print_epoch"] = PrintEpochWithTime(self.engine.silent)
        if self.engine.restrict_train_epochs is not None:
            self._callbacks["restrict_train_epochs"] = RestrictTrainEpochs(self.engine.restrict_train_epochs)
        # TODO: callback for logging time per step

    def outpath_exclude_keys(self) -> list[str]:
        """Config keys that must not influence the output directory name."""
        return [
            self.eval_key,
            "output_dir_name"
        ]

    def _set_outpath(self):
        """
        Derive run_dir/checkpoint_dir from the config values that differ from
        the defaults; falls back to an md5 hash when the name would exceed the
        file-name length limit.
        """
        base: Path = self.engine.output_dir / self.task.output_dir_name / self.optimizer.output_dir_name
        exclude_keys = self.outpath_exclude_keys()
        exclude_keys += self.engine.outpath_irrelevant_engine_keys()
        diffs = concatenate_dict_keys(dict_differences(self._config, self._default_config), exclude_keys=exclude_keys)
        run_dir = ",".join(f"{k}={str(v)}" for k, v in sorted(diffs.items())) if diffs else "default"
        if len(run_dir) > 254:  # max file name length
            hashdir = hashlib.md5(run_dir.encode()).hexdigest()
            log_info(f"folder name {run_dir} is too long, using {hashdir} instead.")
            run_dir = hashdir
        self.run_dir = base / run_dir
        self.checkpoint_dir = self.run_dir / "checkpoints"

    def _generate_configs(self):
        """(Re)build the typed config views from the raw config dict."""
        self.engine = EngineConfig(self._config, self.task_key, self.engine_key)
        self.optimizer = OptimizerConfig(self._config, self.optimizer_key, self.task_key, self.identifier_key)
        self.task = TaskConfig(self._config, self.task_key, self.engine_key, self.identifier_key)
        self.evaluation = EvalConfig(
            self._config,
            eval_key=self.eval_key,
            engine_key=self.engine_key,
            ignore_keys=self.engine.outpath_irrelevant_engine_keys(prefix=f"{self.engine_key}.") + [f"{self.optimizer_key}.output_dir_name", f"{self.task_key}.output_dir_name"]
        )

View file

@ -0,0 +1,171 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Iterable, Optional, Sequence
import traceback
import yaml
from pytorch_fob.engine.run import Run
from pytorch_fob.engine.slurm import Slurm
from pytorch_fob.engine.utils import log_info, log_warn, seconds_to_str, some, str_to_seconds
# module paths executed on the cluster via `python -m`
FOB_RUN_SCRIPT = "pytorch_fob.run_experiment"
FOB_EVAL_SCRIPT = "pytorch_fob.evaluate_experiment"
def argcheck_allequal_engine(
        runs: list[Run],
        keys: list[str],
        reason: str = "'engine.run_scheduler=slurm_array'"
) -> None:
    """Raise ValueError unless every run agrees with the first one on all engine `keys`."""
    reference = runs[0]
    all_equal = all(
        run.engine[key] == reference.engine[key]
        for key in keys
        for run in runs[1:]
    )
    if not all_equal:
        req = ", ".join("engine." + key for key in keys)
        raise ValueError(f"All runs must have the same values for {req} when using {reason}")
def export_experiment(run: Run, experiment: dict[str, Any]) -> Path:
    """Write the full experiment dict to <run_dir>/experiment.yaml and return that path."""
    run.run_dir.mkdir(parents=True, exist_ok=True)
    outfile = run.run_dir / "experiment.yaml"
    with open(outfile, "w", encoding="utf8") as handle:
        yaml.safe_dump(experiment, handle)
    return outfile
def process_args(args: dict[str, str], run: Run) -> None:
    """
    Fill in derived sbatch defaults in-place: scale the wall time by
    engine.sbatch_time_factor and default GPU/task/CPU counts from the engine
    config unless the user already supplied them.
    """
    if "time" in args:
        raw = args["time"]
        seconds = str_to_seconds(raw) if isinstance(raw, str) else raw
        args["time"] = seconds_to_str(int(run.engine.sbatch_time_factor * seconds))
    gpus_given = "gres" in args or "gpus" in args
    if not gpus_given:
        args["gres"] = f"gpu:{run.engine.devices}"
    if all(not key.startswith("ntasks") for key in args):
        args["ntasks-per-node"] = str(run.engine.devices)
    if all(not key.startswith("cpus") for key in args):
        args["cpus-per-task"] = str(run.engine.workers)
def wrap_template(template_path: Optional[Path], command: str, placeholder: str = "__FOB_COMMAND__") -> str:
    """
    Embed `command` into the script template at `template_path`: substituted for
    `placeholder` if present, otherwise appended after the template. Without a
    template the command is returned unchanged.
    """
    if template_path is None:
        return command
    with open(template_path, "r", encoding="utf8") as handle:
        template = handle.read()
    if placeholder in template:
        return template.replace(placeholder, command)
    return f"{template}\n{command}\n"
def get_command(experiment_file: Path, index: Optional[str], plot: bool) -> str:
    """Build the srun command for one worker (or, with plot=True, the plotting job)."""
    script = FOB_EVAL_SCRIPT if plot else FOB_RUN_SCRIPT
    plot_flag = "" if plot else "engine.plot=false"
    scheduler_arg = "" if index is None else f"engine.run_scheduler=single:{index}"
    return f"srun python -m {script} {experiment_file} {scheduler_arg} {plot_flag}"
def get_job_name(run: Run) -> str:
    """Slurm job name in the form FOB-<task>-<optimizer>."""
    return "-".join(("FOB", run.task.name, run.optimizer.name))
def get_slurm(job_name: str, args: dict[str, str], log_dir: Path, scripts_dir: Path) -> Slurm:
    """Construct a Slurm wrapper with absolute log and script directories."""
    resolved_logs = str(log_dir.resolve())
    resolved_scripts = str(scripts_dir.resolve())
    return Slurm(
        job_name,
        args,
        log_dir=resolved_logs,
        scripts_dir=resolved_scripts,
        bash_strict=False  # TODO: maybe add arg or just remove 'nounset'
    )
def run_slurm(
        job_name: str,
        command: str,
        args: dict[str, str],
        log_dir: Path,
        save_sbatch_scripts: Optional[Path] = None,
        dependencies: Sequence[int] = tuple(),
        dependency_type: str = "afterok"
) -> Optional[int]:
    """
    Submit `command` via sbatch. The generated script is kept in
    `save_sbatch_scripts` when given, otherwise written to a temporary
    directory that is removed after submission. Returns the job id or None.
    """
    if save_sbatch_scripts is not None:
        slurm = get_slurm(job_name, args, log_dir, scripts_dir=save_sbatch_scripts)
        return slurm.run(command, name_addition="", depends_on=dependencies, dependency_type=dependency_type)
    with TemporaryDirectory() as tmpdir:
        # submission must happen while the temp dir still exists
        slurm = get_slurm(job_name, args, log_dir, scripts_dir=Path(tmpdir).resolve())
        return slurm.run(command, name_addition="", depends_on=dependencies, dependency_type=dependency_type)
def run_plotting_job(
        experiment_file: Path,
        args: dict[str, str],
        log_dir: Path,
        dependencies: Sequence[int],
        template: Optional[Path] = None
) -> None:
    """
    Submit a lightweight CPU-only job that plots the results once all
    `dependencies` have finished (afterany, so it also runs on failures).
    Mutates `args` in place to strip GPU/array/task settings.
    """
    args["time"] = seconds_to_str(300)  # 5 minutes should be plenty of time to plot
    args.pop("array", None)
    # no gpus needed for plotting
    args.pop("gpus", None)
    args.pop("gres", None)
    # just one cpu per node for plotting
    for key in [k for k in args.keys() if k.startswith(("ntasks", "cpus"))]:
        args.pop(key)
    args["nodes"] = "1"
    args["ntasks-per-node"] = "1"
    args["cpus-per-task"] = "2"
    command = wrap_template(template, get_command(experiment_file, None, plot=True))
    run_slurm("FOB-plot", command, args, log_dir, dependencies=dependencies, dependency_type="afterany")
def slurm_array(runs: list[Run], experiment: dict[str, Any]) -> None:
    """
    Submit all runs as a single slurm job array (plus an optional dependent
    plotting job). Requires all runs to share the relevant engine settings.
    """
    equal_req = ["devices", "workers", "sbatch_args", "slurm_log_dir", "sbatch_script_template", "run_scheduler"]
    argcheck_allequal_engine(runs, equal_req)
    first = runs[0]  # all runs have the same args
    args = first.engine.sbatch_args
    log_dir = some(first.engine.slurm_log_dir, default=first.engine.output_dir / "slurm_logs")
    args.setdefault("array", f"1-{len(runs)}")
    process_args(args, first)
    # export the (identical) experiment file into every run dir, use the first one
    experiment_file = [export_experiment(r, experiment).resolve() for r in runs][0]
    command = wrap_template(
        first.engine.sbatch_script_template,
        get_command(experiment_file, "$SLURM_ARRAY_TASK_ID", plot=False)
    )
    job_id = run_slurm(get_job_name(first), command, args, log_dir, save_sbatch_scripts=first.engine.save_sbatch_scripts)
    if job_id is not None and first.engine.plot:
        run_plotting_job(experiment_file, args, log_dir, [job_id], template=first.engine.sbatch_script_template)
def slurm_jobs(runs: list[Run], experiment: dict[str, Any]) -> list[int]:
    """Submit each run as its own slurm job and return the submitted job ids."""
    job_ids = []
    experiment_file = Path()
    for index, run in enumerate(runs, start=1):
        args = run.engine.sbatch_args
        process_args(args, run)
        log_dir = some(run.engine.slurm_log_dir, default=run.run_dir / "slurm_logs")
        experiment_file = export_experiment(run, experiment).resolve()
        command = wrap_template(run.engine.sbatch_script_template, get_command(experiment_file, str(index), plot=False))
        job_id = run_slurm(get_job_name(run), command, args, log_dir, save_sbatch_scripts=run.engine.save_sbatch_scripts)
        if job_id is not None:
            job_ids.append(job_id)
    if len(job_ids) > 0 and any(r.engine.plot for r in runs):
        # NOTE: args/log_dir/experiment_file deliberately reuse the values of the
        # last loop iteration, exactly as before
        argcheck_allequal_engine(
            runs,
            ["slurm_log_dir", "sbatch_script_template"],
            reason="'engine.plot=true' with 'engine.run_scheduler=slurm_jobs'"
        )
        run_plotting_job(experiment_file, args, log_dir, job_ids, template=runs[0].engine.sbatch_script_template)
    return job_ids
def sequential(runs: Iterable[Run], n_runs: int, experiment: dict[str, Any]):
    """
    Execute all runs one after another in this process; a RuntimeError in one
    run (e.g. from detect_anomaly) is logged and does not stop the remaining runs.
    """
    for idx, run in enumerate(runs, start=1):
        log_info(f"Starting run {idx}/{n_runs}.")
        export_experiment(run, experiment)
        try:
            run.start()
        except RuntimeError as _e:  # detect_anomaly raises RuntimeError
            trace = traceback.format_exc()
            log_warn(f"Run {idx}/{n_runs} failed with {trace}.")

View file

@ -0,0 +1,181 @@
"""
The MIT License (MIT)
Copyright (c) 2015 Brent Pedersen - Bioinformatics
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Adapted from https://github.com/brentp/slurmpy
"""
from __future__ import print_function
import sys
import os
import subprocess
import tempfile
import atexit
import hashlib
import datetime
from typing import Optional, Sequence
# sbatch script template: {log_dir}/{name}/{header}/{bash_setup} are filled by
# Slurm.__str__; "__script__" is replaced with the payload command in Slurm.run.
TMPL = """\
#!/bin/bash
#SBATCH -e {log_dir}/{name}.%J.err
#SBATCH -o {log_dir}/{name}.%J.out
#SBATCH -J {name}
{header}
{bash_setup}
__script__"""
def tmp(suffix=".sh"):
    """
    Return a temporary file path that is deleted at interpreter exit.

    Uses ``tempfile.mkstemp`` instead of the deprecated, race-prone
    ``tempfile.mktemp``: the file is actually created (and its fd closed), so
    the name cannot be hijacked between creation and use, and the registered
    ``os.unlink`` can no longer fail on a file that was never written.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    atexit.register(os.unlink, path)
    return path
class Slurm(object):
    """
    Thin wrapper around ``sbatch``: renders an sbatch script from ``TMPL``
    and submits it. Adapted from https://github.com/brentp/slurmpy.
    """

    def __init__(self, name, slurm_kwargs=None, tmpl=None,
                 date_in_name=True, scripts_dir="slurm-scripts",
                 log_dir='logs', bash_strict=True):
        # each slurm_kwargs entry becomes one "#SBATCH" header line
        if slurm_kwargs is None:
            slurm_kwargs = {}
        if tmpl is None:
            tmpl = TMPL
        self.log_dir = log_dir
        self.bash_strict = bash_strict
        header = []
        if 'time' not in slurm_kwargs.keys():
            slurm_kwargs['time'] = '84:00:00'  # default wall-time limit
        for k, v in slurm_kwargs.items():
            # long options become "--key=value", single-letter ones "-k value"
            if len(k) > 1:
                k = "--" + k + "="
            else:
                k = "-" + k + " "
            header.append(f"#SBATCH {k}{v}")
        # add bash setup list to collect bash script config
        bash_setup = []
        if bash_strict:
            bash_setup.append("set -eo pipefail -o nounset")
        self.header = "\n".join(header)
        self.bash_setup = "\n".join(bash_setup)
        # sanitize the job name: spaces -> dashes, keep only alphanumerics and dashes
        self.name = "".join(x for x in name.replace(
            " ", "-") if x.isalnum() or x == "-")
        self.tmpl = tmpl
        self.slurm_kwargs = slurm_kwargs
        if scripts_dir is not None:
            self.scripts_dir = os.path.abspath(scripts_dir)
        else:
            self.scripts_dir = None
        self.date_in_name = bool(date_in_name)

    def __str__(self):
        # render the sbatch script (without the payload command)
        return self.tmpl.format(name=self.name, header=self.header,
                                log_dir=self.log_dir,
                                bash_setup=self.bash_setup)

    def _tmpfile(self):
        # script goes to a throwaway temp file unless a scripts_dir was configured
        if self.scripts_dir is None:
            return tmp()
        else:
            for _dir in [self.scripts_dir, self.log_dir]:
                if not os.path.exists(_dir):
                    os.makedirs(_dir)
            return f"{self.scripts_dir}/{self.name}.sh"

    def run(self,
            command: str,
            name_addition: Optional[str] = None,
            cmd_kwargs: Optional[dict[str, str]] = None,
            _cmd: str = "sbatch",
            tries: int = 1,
            depends_on: Optional[Sequence[int]] = None,
            dependency_type: str = "afterok"
            ) -> Optional[int]:
        """
        command: a bash command that you want to run
        name_addition: if not specified, the sha1 of the command to run
                       appended to job name. if it is "date", the yyyy-mm-dd
                       date will be added to the job name.
        cmd_kwargs: dict of extra arguments to fill in command
                    (so command itself can be a template).
        _cmd: submit command (change to "bash" for testing).
        tries: try to run a job either this many times or until the first
               success.
        depends_on: job ids that this depends on before it is run
        dependency_type: after, afterok, afterany, afternotok

        Returns the id of the first submitted job, or None if sbatch did not
        report a successful submission.
        """
        if name_addition is None:
            name_addition = hashlib.sha1(command.encode("utf-8")).hexdigest()
        if self.date_in_name:
            name_addition += "-" + str(datetime.date.today())
        name_addition = name_addition.strip(" -")
        if cmd_kwargs is None:
            cmd_kwargs = {}
        # remember the bare name; it is restored after each submission attempt
        n = self.name
        self.name = self.name.strip(" -")
        self.name += ("-" + name_addition.strip(" -"))
        args = []
        for k, v in cmd_kwargs.items():
            args.append(f"export {k}={v}")
        args = "\n".join(args)
        # splice env exports + payload command into the rendered template
        tmpl = str(self).replace("__script__", args + "\n###\n" + command)
        if depends_on is None or (len(depends_on) == 1 and depends_on[0] is None):
            depends_on = []
        with open(self._tmpfile(), "w", encoding="utf8") as sh:
            sh.write(tmpl)
        job_id = None
        for itry in range(1, tries + 1):
            args = [_cmd]
            if depends_on is not None and len(depends_on) > 0:
                dep = f"--dependency={dependency_type}:" + ":".join([str(x) for x in depends_on])
                args.append(dep)
            if itry > 1:
                # retries only fire if the previous attempt's job failed
                mid = f"--dependency=afternotok:{job_id}"
                args.append(mid)
            args.append(sh.name)
            res = subprocess.check_output(args).strip()
            print(res.decode(), file=sys.stderr)
            self.name = n
            if not res.startswith(b"Submitted batch"):
                return None
            j_id = int(res.split()[-1])
            if itry == 1:
                job_id = j_id
        return job_id
# run the module's doctests when executed directly
if __name__ == "__main__":
    import doctest
    doctest.testmod()

View file

@ -0,0 +1,228 @@
import logging
from pathlib import Path
from typing import Any, Callable, Iterable, Optional, Type
import json
import math
import signal
import torch
from lightning_utilities.core.rank_zero import rank_zero_only, rank_zero_info, rank_zero_debug, log
def set_loglevel(level: str):
    """
    Set the verbosity of the "lightning.pytorch" logger.

    Accepts "debug", "info", "warn", "error" or "silent"; any other value
    leaves the level untouched.
    """
    levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warn": logging.WARNING,
        "error": logging.ERROR,
        "silent": logging.CRITICAL,
    }
    pytorch_logger = logging.getLogger("lightning.pytorch")
    if level in levels:
        pytorch_logger.setLevel(levels[level])
@rank_zero_only
def rank_zero_print(*args: Any, **kwargs: Any):
    """Print only on the rank-zero process (no-op on other ranks)."""
    print(*args, **kwargs)
@rank_zero_only
def log_warn(msg: str, *args: Any, prefix: str = "[FOB WARNING] ", **kwargs: Any):
    """Warning log on rank zero only, with the FOB prefix prepended."""
    return log.warning(prefix + msg, *args, **kwargs)
def log_info(msg: str, *args: Any, prefix: str = "[FOB INFO] ", **kwargs: Any):
    """Rank-zero info log with the FOB prefix prepended."""
    return rank_zero_info(prefix + msg, *args, **kwargs)
def log_debug(msg: str, *args: Any, prefix: str = "[FOB DEBUG] ", **kwargs: Any):
    """Rank-zero debug log with the FOB prefix prepended."""
    return rank_zero_debug(prefix + msg, *args, **kwargs)
def write_results(results, filepath: Path):
    """Serialize `results` as pretty-printed JSON to `filepath` and report the location."""
    with open(filepath, "w", encoding="utf8") as out:
        json.dump(results, out, indent=4)
    print(f"Saved results into {filepath}.")
def wrap_list(x: Any) -> list[Any]:
    """Return `x` unchanged when it is already a list, otherwise wrap it in one."""
    return x if isinstance(x, list) else [x]
def calculate_steps(epochs: int, datapoints: int, devices: int, batch_size: int) -> int:
    """
    Total number of optimizer steps for `epochs` epochs over `datapoints`
    samples with the given per-device batch size and device count.

    Fix: uses exact integer ceiling division instead of the previous
    float-based ``math.ceil(datapoints / batch_size / devices)``, which loses
    precision for very large sample counts.
    """
    steps_per_epoch = -(-datapoints // (batch_size * devices))
    return steps_per_epoch * epochs
def some(*args, default):
    """
    Return the first argument that is not None; `default` if all are None
    (or no arguments were given).
    """
    for candidate in args:
        if candidate is not None:
            return candidate
    return default
def maybe_abspath(path: Optional[str | Path]) -> Optional[Path]:
if path is None:
return None
return Path(path).resolve()
def findfirst(f: Callable, xs: Iterable):
    """Return the first element of `xs` for which `f` is truthy, or None."""
    return next((item for item in xs if f(item)), None)
def trainer_strategy(devices: int | list[int] | str) -> str:
if isinstance(devices, str):
return "auto"
ndevices = devices if isinstance(devices, int) else len(devices)
return "ddp" if ndevices > 1 else "auto"
def gpu_suited_for_compile() -> bool:
    """
    True if the current CUDA device's compute capability is one of the
    architectures known to work well with ``torch.compile`` (7.0 / 8.0 / 9.0,
    i.e. V100 / A100 / H100 class).

    Fix: previously this implicitly returned None when CUDA was unavailable;
    it now always returns a bool (False without CUDA — backward compatible,
    as both values are falsy).
    """
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() in ((7, 0), (8, 0), (9, 0))
def precision_with_fallback(precision: str) -> str:
    """
    Check if cuda supports bf16, if not using cuda or if not available return 16 instead of bf16.

    Fix: the no-CUDA branch used to strip the first two characters of ANY
    precision string (e.g. "16-mixed" -> "-mixed"); it now only removes a
    leading "bf" prefix, leaving non-bfloat precisions untouched.
    """
    if not torch.cuda.is_available():
        log_warn("Warning: No CUDA available. Results can be different!")
        return precision.removeprefix("bf")
    if precision.startswith("bf") and not torch.cuda.is_bf16_supported():
        log_warn("Warning: GPU does not support bfloat16. Results can be different!")
        return precision[2:]
    return precision
def str_to_seconds(s: str) -> int:
    """Parse an 'HH:MM:SS' string into a total number of seconds."""
    parts = s.split(":")
    assert len(parts) == 3, f"Invalid time format: {s}. Use 'HH:MM:SS'."
    hours, minutes, seconds = (int(part) for part in parts)
    return hours * 3600 + minutes * 60 + seconds
def seconds_to_str(total_seconds: int, sep: str = ":") -> str:
    """Format a second count as zero-padded 'HH:MM:SS', joined with `sep`."""
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return sep.join(f"{part:02d}" for part in (hours, minutes, seconds))
def begin_timeout(delay=10, show_threads=False):
    """
    Schedule a SIGALRM after `delay` seconds (Unix only); with no handler
    installed this hard-terminates a hung process. Optionally dumps the stack
    of every live thread first to help diagnose the hang.

    Fix: the inline comment claimed a fixed 10-second timeout even though the
    delay is configurable.
    """
    if show_threads:
        import sys
        import traceback
        import threading
        thread_names = {t.ident: t.name for t in threading.enumerate()}
        # NOTE: sys._current_frames is CPython-private, but it is the standard
        # way to inspect the stacks of other threads.
        for thread_id, frame in sys._current_frames().items():
            print(f"Thread {thread_names.get(thread_id, thread_id)}:")
            traceback.print_stack(frame)
            print()
    signal.alarm(delay)  # raise SIGALRM after `delay` seconds
def path_to_str_inside_dict(d: dict) -> dict:
    """Recursively replace every pathlib.Path value in `d` with its string form."""
    return convert_type_inside_dict(d, src=Path, tgt=str)
def convert_type_inside_dict(d: dict, src: Type, tgt: Type) -> dict:
    """
    Return a copy of `d` where every value of type `src` (at any nesting depth)
    is converted via `tgt`; nested dicts are processed recursively first.
    """
    converted = {}
    for key, value in d.items():
        if isinstance(value, dict):
            value = convert_type_inside_dict(value, src, tgt)
        converted[key] = tgt(value) if isinstance(value, src) else value
    return converted
def dict_differences(custom: dict[str, Any], default: dict[str, Any]) -> dict[str, Any]:
    """
    Recursively returns a dictionary with the items in `custom` that are different or missing from `default`.
    Example:
        >>> dict_differences({"hi": 3, "bla": {"a": 2, "b": 2}}, {"hi": 2, "bla": {"a": 1, "b": 2}})
        {'hi': 3, 'bla': {'a': 2}}
    """
    diff: dict[str, Any] = {}
    for key, value in custom.items():
        if key not in default:
            diff[key] = value
            continue
        other = default[key]
        if other == value:
            continue
        both_dicts = isinstance(value, dict) and isinstance(other, dict)
        diff[key] = dict_differences(value, other) if both_dicts else value
    return diff
def concatenate_dict_keys(
    d: dict[str, Any],
    parent_key: str = "",
    sep: str = ".",
    exclude_keys: Iterable[str] = tuple()
) -> dict[str, Any]:
    """
    Example:
        >>> concatenate_dict_keys({ "A": { "B": { "C": 1, "D": 2 }, "E": { "F": 3 } } })
        {'A.B.C': 1, 'A.B.D': 2, 'A.E.F': 3}
        >>> concatenate_dict_keys({ "A": { "B": { "C": 1, "D": 2 }, "E": { "F": 3 } } }, exclude_keys=["B"])
        {'A.E.F': 3}
    """
    flat: dict[str, Any] = {}
    for key, value in d.items():
        if key in exclude_keys:
            continue
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(concatenate_dict_keys(value, full_key, sep, exclude_keys))
        else:
            flat[full_key] = value
    return flat
def sort_dict_recursively(d: dict) -> dict:
    """Return a copy of `d` with keys sorted at every nesting level."""
    return {
        key: sort_dict_recursively(value) if isinstance(value, dict) else value
        for key, value in sorted(d.items())
    }
class EndlessList(list):
    """
    Returns first element if out of bounds. Otherwise same as list.
    """
    def __getitem__(self, index):
        # Any index past the end wraps to element 0 (non-empty lists only).
        out_of_range = index >= len(self) and len(self) > 0
        return super().__getitem__(0 if out_of_range else index)
class AttributeDict(dict):
    """
    Dict whose entries are also readable as attributes (d.key == d["key"]).
    Real attributes and methods take precedence over dict entries.

    Fix: the previous implementation overrode __getattribute__ and let a bare
    KeyError escape for missing attributes, which broke hasattr(),
    getattr(d, key, default), copy and pickling. __getattr__ is only invoked
    after normal attribute lookup fails, and the KeyError is converted to the
    AttributeError those protocols expect.
    """
    def __getattr__(self, key: str) -> Any:
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err