mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Convert FOB submodule to regular folder
This commit is contained in:
parent
94f046ad40
commit
94825011a0
74 changed files with 4563 additions and 0 deletions
|
|
@ -0,0 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
from pytorch_fob.engine.engine import Engine
|
||||
|
||||
|
||||
def repository_root() -> Path:
    """Return the repository root directory (two levels above this file)."""
    here = Path(__file__).resolve()
    return here.parent.parent
|
||||
272
environments/optimizer/FOB/pytorch_fob/engine/callbacks.py
Normal file
272
environments/optimizer/FOB/pytorch_fob/engine/callbacks.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
import math
|
||||
import time
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import deepspeed
|
||||
import torch
|
||||
from lightning import Callback, LightningModule, Trainer
|
||||
from lightning_utilities.core.rank_zero import rank_zero_only
|
||||
from torch.linalg import vector_norm
|
||||
|
||||
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, seconds_to_str
|
||||
|
||||
|
||||
class RestrictTrainEpochs(Callback):
    """Counts number of epochs since start of training and stops if max_epochs is reached."""

    def __init__(self, max_epochs: int):
        super().__init__()
        self.max_epochs = max_epochs
        self.epochs = 0
        # set when resuming from a checkpoint: the first epoch-end belongs to
        # the checkpointed epoch and must not be counted again
        self.skip_first = False

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        """Reset the epoch counter and clear any pending stop request."""
        log_debug(f"Training for {self.max_epochs} epochs...")
        self.epochs = 0
        trainer.should_stop = False

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Count the finished epoch and request a stop once the limit is hit."""
        if not self.skip_first:
            self.epochs += 1
            log_debug(f"Epoch {self.epochs}/{self.max_epochs}")
            # TODO: test for DDP, do we need 'trainer.strategy.reduce_boolean_decision'?
            if self.epochs >= self.max_epochs:
                log_debug(f"Stopping training after {self.epochs} epochs")
                trainer.should_stop = True
        else:
            self.skip_first = False

    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint):
        # checkpoint loads the model at the end of the epoch, so we do not count the first epoch
        self.skip_first = True
|
||||
|
||||
|
||||
class OptimizerTime(Callback):
    """Maintains a running mean over epochs of the mean optimizer step time (ms)."""

    def __init__(self):
        super().__init__()
        self.total_mean_optimizer_step_time_ms: float = 0.0
        self.total_epochs: int = 0

    def on_train_epoch_end(self, trainer, pl_module):
        """Log this epoch's mean step time and fold it into the running mean."""
        step_times = pl_module.optimizer_times_ms
        if not step_times:
            return
        epoch_mean = sum(step_times) / len(step_times)
        pl_module.log("mean_optimizer_step_time_ms", epoch_mean, on_step=False, on_epoch=True, sync_dist=True)

        # incremental running mean: previous total + this epoch, divided by new count
        previous_total = self.total_mean_optimizer_step_time_ms * self.total_epochs
        self.total_epochs += 1
        self.total_mean_optimizer_step_time_ms = (previous_total + epoch_mean) / self.total_epochs

        # clear per-step measurements for the next epoch
        pl_module.optimizer_times_ms = []  # type: ignore

    def state_dict(self) -> dict[str, float | int]:
        """Serialize the running mean so it survives checkpointing."""
        return {"running_mean": self.total_mean_optimizer_step_time_ms, "total_epochs": self.total_epochs}

    def load_state_dict(self, state_dict: dict[str, float | int]):
        """Restore the running mean from a checkpoint."""
        self.total_mean_optimizer_step_time_ms = state_dict["running_mean"]
        self.total_epochs = state_dict["total_epochs"]  # type: ignore
|
||||
|
||||
|
||||
class PrintEpochWithTime(Callback):
    """Prints, at the end of every epoch, how long training and validation took."""

    def __init__(self, active: bool = True):
        super().__init__()
        self.active: bool = active
        self.time: dict[str, Optional[float]]
        self.reset_time()

    def reset_time(self):
        """Clear all recorded timestamps."""
        self.time = {"train_start": None, "val_start": None, "val_end": None}

    @rank_zero_only
    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        if self.active:
            self.time["train_start"] = time.time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        # need to print here since train epoch ends after validation is done
        if not self.active:
            return
        if any(v is None for v in self.time.values()):
            return
        max_epochs = pl_module.config.max_epochs
        train_time = math.ceil(time.time() - self.time["train_start"])  # type: ignore
        val_time = math.ceil(self.time["val_end"] - self.time["val_start"])  # type: ignore
        log_info(
            f"Finished training epoch {trainer.current_epoch + 1} of {max_epochs}. Time spent: training: {seconds_to_str(train_time - val_time)}, validation: {seconds_to_str(val_time)}, total: {seconds_to_str(train_time)}."
        )
        self.reset_time()

    @rank_zero_only
    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        if self.active:
            self.time["val_start"] = time.time()

    @rank_zero_only
    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if self.active:
            self.time["val_end"] = time.time()
|
||||
|
||||
|
||||
def metric_fn(metric: str, v: torch.Tensor, override: Optional[float] = None) -> float:
    """Reduce tensor ``v`` to a single float according to ``metric``.

    If ``override`` is given it is returned as-is and the tensor is ignored.
    Raises ValueError for an unknown metric name.
    """
    if override is not None:
        return override
    reducers = {
        "mean": lambda t: t.mean(),
        "sum": lambda t: t.sum(),
        "abs_mean": lambda t: t.abs().mean(),
        "std": lambda t: t.std(),
        "abs_std": lambda t: t.abs().std(),
        "min": lambda t: t.min(),
        "max": lambda t: t.max(),
        "l1": lambda t: vector_norm(t, ord=1),
        "l2": lambda t: vector_norm(t, ord=2),
        "sq_mean": lambda t: (t**2).mean(),
        "sq_sum": lambda t: (t**2).sum(),
    }
    if metric not in reducers:
        raise ValueError(f"unknown metric {metric}")
    return reducers[metric](v).item()
|
||||
|
||||
|
||||
def add_metrics_to_stats(
    stats: dict[str, float],
    prefix: str,
    name: str,
    v: torch.Tensor,
    metrics: Iterable[str],
    override: Optional[float] = None,
):
    """Evaluate every metric on ``v`` and store it in ``stats``.

    Keys have the form '{prefix}/{name}/{metric}'. ``override`` is forwarded
    to ``metric_fn`` and, when set, replaces every computed value.
    """
    stats.update(
        {f"{prefix}/{name}/{metric}": metric_fn(metric, v, override=override) for metric in metrics}
    )
|
||||
|
||||
|
||||
class LogTrainingStats(Callback):
    """Periodically logs statistics of parameters, gradients, optimizer state
    and effective learning rates to all attached trainer loggers.

    The logging interval can optionally grow during training (see
    'change_log_interval_every_n_steps' and 'log_interval_factor'), clamped
    between 'min_log_interval' and 'max_log_interval'.
    """

    def __init__(
        self,
        log_gradient: bool = True,
        log_params: bool = True,
        log_quantiles: bool = False,
        log_momentum: bool = False,
        log_lrs: bool = True,
        log_every_n_steps: int = 50,
        change_log_interval_every_n_steps: Optional[int] = None,
        log_interval_factor: float = 2.0,
        min_log_interval: int = 1,
        max_log_interval: Optional[int] = None,
        metrics: Iterable[str] = ("mean", "abs_mean", "std", "abs_std", "min", "max", "l1", "l2", "sq_mean"),
    ):
        super().__init__()
        # which statistics families to log
        self.log_gradient = log_gradient
        self.log_params = log_params
        self.log_quantiles = log_quantiles
        self.log_momentum = log_momentum
        self.log_lrs = log_lrs
        # current logging interval (in optimizer steps); may change over time
        self.log_every_n_steps = log_every_n_steps
        # every this many global steps, the interval is multiplied by
        # 'log_interval_factor' (disabled when None)
        self.change_log_interval_every_n_steps = change_log_interval_every_n_steps
        self.log_interval_factor = log_interval_factor
        self.min_log_interval = min_log_interval
        self.max_log_interval = max_log_interval
        # names understood by 'metric_fn'
        self.metrics = metrics

    def _check_and_adjust_log_interval(self, trainer: Trainer, pl_module: LightningModule):
        """Grow the interval if due; return True when this step should log."""
        if self.change_log_interval_every_n_steps is not None:
            if trainer.global_step > 0 and trainer.global_step % self.change_log_interval_every_n_steps == 0:
                self.log_every_n_steps = math.ceil(self.log_every_n_steps * self.log_interval_factor)
                # clamp to the configured bounds
                self.log_every_n_steps = max(self.log_every_n_steps, self.min_log_interval)
                if self.max_log_interval is not None:
                    self.log_every_n_steps = min(self.log_every_n_steps, self.max_log_interval)
                pl_module.log("logging_interval", self.log_every_n_steps)
        return trainer.global_step % self.log_every_n_steps == 0

    @rank_zero_only
    def on_before_optimizer_step(self, trainer: Trainer, pl_module: LightningModule, optimizer: torch.optim.Optimizer):
        """Collect per-parameter stats right before the optimizer step and log them."""
        if self._check_and_adjust_log_interval(trainer, pl_module):
            stats = {}
            # quartile positions (0.25, 0.5, 0.75) used for quantile logging
            q = torch.arange(0.25, 1, 0.25).round(decimals=2).to(trainer.model.device)
            for param_group in optimizer.param_groups:
                # assumes each param_group carries a parallel "names" list — set up elsewhere in the project
                for name, param in zip(param_group["names"], param_group["params"]):
                    if self.log_params or self.log_lrs:
                        v_detached = param.detach()

                    if self.log_params:
                        # warn on non-finite parameter values
                        if torch.isnan(v_detached).sum() > 0:
                            log_warn(f"# NaN in param {name}")
                        if torch.isinf(v_detached).sum() > 0:
                            log_warn(f"# Inf in param {name}")

                        add_metrics_to_stats(stats, "param", name, v_detached, self.metrics)

                        # quantiles only for reasonably small tensors (torch.quantile limit)
                        if self.log_quantiles and v_detached.size().numel() < 10000000:
                            deciles = torch.quantile(v_detached.float(), q, interpolation="linear")
                            for q_idx, d_val in enumerate(deciles):
                                stats[f"param/{name}/quantile-{q[q_idx]}"] = d_val.item()

                    if (self.log_gradient or self.log_lrs) and param.requires_grad:
                        # NOTE(review): multi-device path assumes a deepspeed
                        # strategy (safe_get_full_grad) — confirm for DDP setups
                        if trainer.num_devices > 1:
                            grad_data = deepspeed.utils.safe_get_full_grad(param)
                        else:
                            grad_data = param.grad
                    else:
                        grad_data = None

                    if grad_data is not None:
                        if torch.isnan(grad_data).sum() > 0:
                            log_warn(f"# NaN in grad {name}")
                        if torch.isinf(grad_data).sum() > 0:
                            log_warn(f"# Inf in grad {name}")

                        if self.log_gradient:
                            # non-finite gradients: log sentinel value -10 instead of real stats
                            if torch.isnan(grad_data).sum() > 0 or torch.isinf(grad_data).sum() > 0:
                                add_metrics_to_stats(stats, "grad", name, grad_data, self.metrics, override=-10.0)
                                if self.log_quantiles and grad_data.size().numel() < 10000000:
                                    for q_idx, _ in enumerate(q):
                                        # NOTE(review): key uses "param/" here while the
                                        # surrounding grad code uses "grad/" — possibly a
                                        # copy-paste slip; confirm intended key
                                        stats[f"param/{name}/quantile-{q[q_idx]}"] = -10

                            # NOTE(review): these stats are written unconditionally and
                            # overwrite the -10 override for grad metrics — confirm intended
                            stats[f"grad/{name}/mean"] = grad_data.mean().item()
                            if len(grad_data.shape) > 1 or grad_data.shape[0] > 1:
                                add_metrics_to_stats(stats, "grad", name, grad_data, self.metrics)

                                if self.log_quantiles and grad_data.size().numel() < 10000000:
                                    deciles = torch.quantile(grad_data.float(), q, interpolation="linear")
                                    for q_idx, d_val in enumerate(deciles):
                                        stats[f"grad/{name}/quantile-{q[q_idx]}"] = d_val.item()

                        if self.log_lrs:
                            # effective lr proxy: ratio of gradient norm to parameter norm
                            grad_norm = vector_norm(grad_data)
                            param_norm = vector_norm(v_detached)
                            effective_lr = (grad_norm / param_norm).item() if param_norm != 0 else 0.0
                            stats[f"param/{name}/effective_lr"] = effective_lr

                    if self.log_momentum or self.log_lrs:
                        if param in optimizer.state:
                            state = optimizer.state[param]
                        else:
                            state = {}

                        if self.log_momentum:
                            # Adam-style optimizers store "exp_avg", SGD stores "momentum_buffer"
                            if "exp_avg" in state:
                                moment1 = state["exp_avg"]
                            elif "momentum_buffer" in state:
                                moment1 = state["momentum_buffer"]
                            else:
                                moment1 = None
                            if moment1 is not None:
                                add_metrics_to_stats(stats, "1st_order_momentum", name, moment1, self.metrics)
                            if "exp_avg_sq" in state:
                                add_metrics_to_stats(stats, "2nd_order_momentum", name, state["exp_avg_sq"], self.metrics)
                        if self.log_lrs and "lr" in state:
                            stats[f"param/{name}/lr"] = state["lr"].item()

            if trainer.loggers is not None:
                for logger in trainer.loggers:
                    logger.log_metrics(stats, step=trainer.global_step)
|
||||
156
environments/optimizer/FOB/pytorch_fob/engine/configs.py
Normal file
156
environments/optimizer/FOB/pytorch_fob/engine/configs.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
from .utils import AttributeDict, EndlessList, convert_type_inside_dict, maybe_abspath, some, wrap_list
|
||||
|
||||
|
||||
class BaseConfig(AttributeDict):
    """AttributeDict that recursively converts nested plain dicts on construction."""

    def __init__(self, config: dict):
        converted = convert_type_inside_dict(config, dict, AttributeDict)
        super().__init__(converted)
|
||||
|
||||
|
||||
class NamedConfig(BaseConfig):
    """Config identified by a name, with an optional separate output directory name."""

    def __init__(
        self,
        config: dict[str, Any],
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        super().__init__(config)
        self.name = config[identifier_key]
        # the output directory defaults to the config's name
        fallback_outdir = self.name
        self.output_dir_name = config.get(outdir_key, fallback_outdir)
|
||||
|
||||
|
||||
class OptimizerConfig(NamedConfig):
    """Optimizer settings, enriched with step/epoch limits taken from the task config."""

    def __init__(
        self,
        config: dict[str, Any],
        optimizer_key: str,
        task_key: str,
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        cfg = dict(config[optimizer_key])
        task_cfg = config[task_key]
        self.lr_interval: Literal["step", "epoch"] = cfg.get("lr_interval", "step")
        # max_steps may be absent from the task config, hence Optional
        self.max_steps: Optional[int] = task_cfg.get("max_steps", None)
        self.max_epochs: int = task_cfg["max_epochs"]
        # mirror the task limits into the optimizer config dict
        cfg["max_steps"] = self.max_steps
        cfg["max_epochs"] = self.max_epochs
        super().__init__(cfg, identifier_key, outdir_key)
|
||||
|
||||
|
||||
class TaskConfig(NamedConfig):
    """Task settings, enriched with data directory and worker count from the engine config."""

    def __init__(
        self,
        config: dict[str, Any],
        task_key: str,
        engine_key: str,
        identifier_key: str = "name",
        outdir_key: str = "output_dir_name"
    ) -> None:
        cfg = dict(config[task_key])
        engine_cfg = config[engine_key]
        self.batch_size: int = cfg["batch_size"]
        self.data_dir = Path(engine_cfg["data_dir"]).resolve()
        self.max_epochs: int = cfg["max_epochs"]
        # max_steps may be absent from the task config, hence Optional
        self.max_steps: Optional[int] = cfg.get("max_steps", None)
        self.target_metric: str = cfg["target_metric"]
        self.target_metric_mode: str = cfg["target_metric_mode"]
        self.workers = engine_cfg["workers"]
        # mirror the engine-sourced values into the task config dict
        cfg["data_dir"] = self.data_dir
        cfg["workers"] = self.workers
        super().__init__(cfg, identifier_key, outdir_key)
|
||||
|
||||
|
||||
class EngineConfig(BaseConfig):
    """Engine settings: trainer options, scheduling, SLURM integration and paths.

    Normalizes paths to absolute, fills defaults for optional entries and
    writes the normalized values back into the underlying dict.
    """

    def __init__(self, config: dict[str, Any], task_key: str, engine_key: str) -> None:
        cfg = dict(config[engine_key])
        self.accelerator = cfg["accelerator"]
        self.deterministic: bool | Literal["warn"] = cfg["deterministic"]
        self.data_dir = Path(cfg["data_dir"]).resolve()
        self.detect_anomaly: bool = cfg["detect_anomaly"]
        self.devices: int = some(cfg["devices"], default=1)
        self.early_stopping: Optional[int] = cfg["early_stopping"]
        # early stopping falls back to the task's target metric when unset
        self.early_stopping_metric: str = some(cfg["early_stopping_metric"], default=config[task_key]["target_metric"])
        self.gradient_clip_alg: str = cfg["gradient_clip_alg"]
        self.gradient_clip_val: Optional[float] = cfg["gradient_clip_val"]
        self.log_extra: bool | dict[str, bool] = cfg["log_extra"]
        # 'logging_inteval' (sic) is kept for backward compatibility with
        # existing callers; prefer the correctly spelled alias below.
        self.logging_inteval: int = cfg["logging_interval"]
        self.logging_interval: int = self.logging_inteval
        # max_steps may be absent from the task config, hence Optional
        self.max_steps: Optional[int] = config[task_key].get("max_steps", None)
        self.optimize_memory: bool = cfg["optimize_memory"]
        self.output_dir = Path(cfg["output_dir"]).resolve()
        self.plot: bool = cfg["plot"]
        self.precision: str = cfg["precision"]
        self.restrict_train_epochs: Optional[int] = cfg["restrict_train_epochs"]
        # 'resume' may be a bool or a checkpoint path (string)
        _resume = cfg.get("resume", False)
        self.resume: Optional[Path] | bool = Path(_resume).resolve() if isinstance(_resume, str) else _resume
        self.run_scheduler: str = cfg["run_scheduler"]
        self.seed: int = cfg["seed"]
        self.seed_mode: str = cfg["seed_mode"]
        self.save_sbatch_scripts: Optional[Path] = maybe_abspath(cfg["save_sbatch_scripts"])
        self.sbatch_args: dict[str, str] = cfg["sbatch_args"]
        self.sbatch_script_template: Optional[Path] = maybe_abspath(cfg["sbatch_script_template"])
        self.sbatch_time_factor: float = cfg["sbatch_time_factor"]
        self.slurm_log_dir: Optional[Path] = maybe_abspath(cfg["slurm_log_dir"])
        self.silent: bool = cfg.get("silent", False)
        self.test: bool = cfg.get("test", True)
        self.train: bool = cfg.get("train", True)
        self.validate: bool = cfg.get("validate", False)
        self.workers: int = cfg["workers"]
        # write the normalized values back so dict access matches the attributes
        cfg["data_dir"] = self.data_dir
        cfg["devices"] = self.devices
        cfg["early_stopping_metric"] = self.early_stopping_metric
        cfg["max_steps"] = self.max_steps
        cfg["output_dir"] = self.output_dir
        cfg["resume"] = self.resume
        cfg["slurm_log_dir"] = self.slurm_log_dir
        cfg["save_sbatch_scripts"] = self.save_sbatch_scripts
        cfg["sbatch_script_template"] = self.sbatch_script_template
        super().__init__(cfg)

    def outpath_relevant_engine_keys(self, prefix: str = "") -> list[str]:
        """Return the engine keys that influence the experiment output path."""
        keys = [
            "accelerator",
            "deterministic",
            "detect_anomaly",
            "devices",
            "early_stopping",
            "gradient_clip_alg",
            "gradient_clip_val",
            "optimize_memory",
            "precision",
            "seed"
        ]
        return [f"{prefix}{k}" for k in keys]

    def outpath_irrelevant_engine_keys(self, prefix: str = "") -> list[str]:
        """Return all engine keys that do NOT influence the experiment output path."""
        return [f"{prefix}{k}" for k in self.keys() if k not in self.outpath_relevant_engine_keys()]
|
||||
|
||||
|
||||
class EvalConfig(BaseConfig):
    """Settings controlling evaluation and plotting of finished experiments."""

    def __init__(self, config: dict[str, Any], eval_key: str, engine_key: str, ignore_keys = None) -> None:
        cfg = dict(config[eval_key])
        # canonical filenames produced by a run, used to locate its results
        self.experiment_files = AttributeDict(dict(
            best_model = "results_best_model.json",
            last_model = "results_final_model.json",
            config = "config.yaml"
        ))
        self.output_types: list[str] = wrap_list(cfg["output_types"])
        experiment_dir = Path(config[engine_key]["output_dir"]).resolve()
        # plots default to a 'plots' folder next to the experiment results
        self.output_dir: Path = some(maybe_abspath(cfg["output_dir"]), default=experiment_dir / "plots")
        self.experiment_name: str = cfg["experiment_name"]
        self.verbose: bool = cfg.get("verbose", False)
        split = cfg.get("split_groups", False)
        self.split_groups: bool | list[str] = split if isinstance(split, bool) else wrap_list(split)
        self.checkpoints: list[Literal["last", "best"]] = wrap_list(cfg["checkpoints"])
        self.column_split_key: Optional[str] = cfg.get("column_split_key", None)
        self.column_split_order: Optional[list[str]] = cfg.get("column_split_order", None)
        self.ignore_keys: list[str] = some(ignore_keys, default=[])
        self.aggregate_groups: list[str] = wrap_list(cfg["aggregate_groups"])
        # write the normalized values back so dict access matches the attributes
        # (fix: 'output_types' was previously assigned twice redundantly)
        cfg["ignore_keys"] = self.ignore_keys
        cfg["output_types"] = self.output_types
        cfg["output_dir"] = self.output_dir
        cfg["aggregate_groups"] = self.aggregate_groups
        cfg["plot"]["x_axis"] = EndlessList(wrap_list(cfg["plot"]["x_axis"]))
        cfg["plot"]["y_axis"] = EndlessList(wrap_list(cfg["plot"]["y_axis"]))
        cfg["split_groups"] = self.split_groups
        super().__init__(cfg)
|
||||
41
environments/optimizer/FOB/pytorch_fob/engine/default.yaml
Normal file
41
environments/optimizer/FOB/pytorch_fob/engine/default.yaml
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
engine:
|
||||
accelerator: gpu # Whether to train on cpu or gpu
|
||||
check_finite: true # Check if 'early_stopping_metric' is finite during training. Aborts training if not. Only active when 'early_stopping' is not null.
|
||||
data_dir: ./data # Where you want to store the training data
|
||||
deterministic: warn # 'warn' tries to use deterministic algorithms if possible, also accepts true or false.
|
||||
detect_anomaly: false # Lightning trainer argument with same name.
|
||||
devices: null # This is set by each task by default, but can be overridden
|
||||
early_stopping: null # The number of epochs to wait before stopping if no improvement is found. Set to null to disable.
|
||||
early_stopping_metric: null # Metric to use for early stopping. If null, uses 'task.target_metric'.
|
||||
gradient_clip_alg: norm # {value, norm} to disable gradient clipping: set 'gradient_clip_val' to null
|
||||
gradient_clip_val: null # DEFAULT: don't clip gradients, expects value in [0, 1]
|
||||
log_extra: false # Activate logging of gradients and more. Can be bool or a dict with the options supported by callback `LogTrainingStats` in `pytorch_fob/engine/callbacks.py`.
|
||||
logging_interval: 50 # Number of steps between each logging step.
|
||||
optimize_memory: false # Use nondeterministic, but memory-efficient algorithms for self-attention
|
||||
output_dir: ./experiments # Where you want to store the results
|
||||
plot: true # Whether to plot the results.
|
||||
precision: bf16-mixed # Floating precision of training, see https://lightning.ai/docs/pytorch/stable/common/precision_basic.html
|
||||
restrict_train_epochs: null # Only train for a specific number of epochs. Set to null to disable. The epochs set here are counted from start of training, so this works with 'resume'.
|
||||
resume: true # You can either pass the path to your checkpoint here or set to true, which loads the last checkpoint.
|
||||
run_scheduler: sequential # How to schedule the runs of the experiment. Supported values:
|
||||
# 'sequential': runs are performed sequentially
|
||||
# 'single:N' where N is the number of the run starting from 1.
|
||||
# 'slurm_array': runs are scheduled using a SLURM array job.
|
||||
# 'slurm_jobs': runs are scheduled using independent SLURM jobs
|
||||
save_sbatch_scripts: null # Path to directory where sbatch scripts will be saved. If null, sbatch scripts will not be saved.
|
||||
sbatch_time_factor: 1 # Time factor for SLURM. Multiplies all default times by this factor.
|
||||
sbatch_args: # Additional arguments to pass to sbatch. Only used if run_scheduler is 'slurm_array'.
|
||||
# ntasks-per-node and gres are set to 'devices' by default
|
||||
# cpus-per-task is set to 'workers' by default
|
||||
nodes: 1
|
||||
mem-per-cpu: 2gb
|
||||
time: 00:30:00 # Each task has its own default time (assumes A100 or similar gpu). Format: HH:MM:SS or seconds.
|
||||
sbatch_script_template: null # Path to template for the sbatch script. Script can contain placeholder '__FOB_COMMAND__'. Otherwise it will be executed before the experiment. 'sbatch_args' will be added to the beginning of the script.
|
||||
slurm_log_dir: null # Default: 'output_dir/slurm_logs' for run_scheduler 'slurm_array' and 'run_dir/slurm_logs' for run_scheduler 'slurm_jobs'
|
||||
seed: 42 # The seed to use for the experiment
|
||||
seed_mode: fixed # Currently only supports 'fixed'
|
||||
silent: false # whether to hide progress bars. Recommended when writing outputs to a log file.
|
||||
test: true # Whether to test the model.
|
||||
train: true # Whether to train the model.
|
||||
validate: false # Whether to validate the model after training (only useful if you are interested in the results, for example for HPO).
|
||||
workers: 16 # The number of processes to use for dataloading
|
||||
228
environments/optimizer/FOB/pytorch_fob/engine/engine.py
Normal file
228
environments/optimizer/FOB/pytorch_fob/engine/engine.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Iterable, Iterator, Literal, Optional
|
||||
from pathlib import Path
|
||||
from matplotlib.figure import Figure
|
||||
from pandas import DataFrame, concat, json_normalize
|
||||
from pytorch_fob.engine.configs import EvalConfig
|
||||
from pytorch_fob.engine.grid_search import grid_search
|
||||
from pytorch_fob.engine.parser import YAMLParser
|
||||
from pytorch_fob.engine.run import Run
|
||||
from pytorch_fob.engine.run_schedulers import sequential, slurm_array, slurm_jobs
|
||||
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, some, sort_dict_recursively
|
||||
from pytorch_fob.evaluation import evaluation_path
|
||||
from pytorch_fob.evaluation.plot import create_figure, get_output_file_path, save_files, set_plotstyle
|
||||
from pytorch_fob.optimizers import lr_schedulers_path, optimizer_path, optimizer_names
|
||||
from pytorch_fob.tasks import task_path, task_names
|
||||
|
||||
|
||||
def engine_path() -> Path:
    """Return the directory containing this engine module."""
    module_file = Path(__file__).resolve()
    return module_file.parent
|
||||
|
||||
|
||||
class Engine():
|
||||
def __init__(self) -> None:
|
||||
self._runs = []
|
||||
self._defaults = []
|
||||
self._experiment = {}
|
||||
self._experiment_file = None
|
||||
self._block_plotting = False
|
||||
self.task_key = "task"
|
||||
self.optimizer_key = "optimizer"
|
||||
self.engine_key = "engine"
|
||||
self.eval_key = "evaluation"
|
||||
self.identifier_key = "name"
|
||||
self.default_file_name = "default.yaml"
|
||||
self.parser = YAMLParser()
|
||||
|
||||
def run_experiment(self) -> Optional[list[int]]:
|
||||
assert len(self._runs) > 0, "No runs in experiment, make sure to call 'parse_experiment' first."
|
||||
scheduler = self._runs[0][self.engine_key]["run_scheduler"]
|
||||
assert all(map(lambda x: x[self.engine_key]["run_scheduler"] == scheduler, self._runs)), \
|
||||
"You cannot perform gridsearch on 'run_scheduler'."
|
||||
if scheduler == "sequential":
|
||||
sequential(self.runs(), len(self._runs), self._experiment)
|
||||
elif scheduler.startswith("single"):
|
||||
n = int(scheduler.rsplit(":", 1)[-1])
|
||||
log_info(f"Starting run {n}/{len(self._runs)}.")
|
||||
run = self._make_run(n)
|
||||
run.start()
|
||||
elif scheduler == "slurm_array":
|
||||
self._block_plotting = True
|
||||
slurm_array(list(self.runs()), self._experiment)
|
||||
elif scheduler == "slurm_jobs":
|
||||
self._block_plotting = True
|
||||
return slurm_jobs(list(self.runs()), self._experiment)
|
||||
else:
|
||||
raise ValueError(f"Unsupported run_scheduler: {scheduler=}.")
|
||||
|
||||
def parse_experiment_from_file(self, file: Path, extra_args: Iterable[str] = tuple()):
|
||||
self._experiment_file = file.resolve()
|
||||
searchspace: dict[str, Any] = self.parser.parse_yaml(self._experiment_file)
|
||||
self.parse_experiment(searchspace, extra_args)
|
||||
|
||||
def parse_experiment(self, searchspace: dict[str, Any], extra_args: Iterable[str] = tuple()):
|
||||
self.parser.parse_args_into_searchspace(searchspace, extra_args)
|
||||
# normalize experiment
|
||||
self._named_dicts_to_list(
|
||||
searchspace,
|
||||
[self.optimizer_key, self.task_key],
|
||||
[optimizer_names(), task_names()]
|
||||
)
|
||||
searchspace = sort_dict_recursively(searchspace)
|
||||
self._experiment = deepcopy(searchspace)
|
||||
# exclude plotting from gridsearch
|
||||
if self.eval_key in searchspace:
|
||||
eval_config = searchspace.pop(self.eval_key)
|
||||
else:
|
||||
eval_config = {}
|
||||
log_debug("Performing gridsearch...")
|
||||
self._runs = grid_search(searchspace)
|
||||
log_debug(f"Found {len(self._runs)} runs.")
|
||||
for run in self._runs:
|
||||
run[self.eval_key] = eval_config
|
||||
self._fill_runs_from_default(self._runs)
|
||||
self._fill_defaults()
|
||||
|
||||
def runs(self) -> Iterator[Run]:
|
||||
"""
|
||||
Creates and initializes runs from parsed run config.
|
||||
"""
|
||||
for n, _ in enumerate(self._runs, start=1):
|
||||
yield self._make_run(n)
|
||||
|
||||
def prepare_data(self):
|
||||
prepared = set()
|
||||
for n, t in enumerate(self._runs, start=1):
|
||||
name = t["task"]["name"]
|
||||
if name not in prepared:
|
||||
run = self._make_run(n)
|
||||
log_info(f"Setting up data for {run.task_key} '{run.task.name}'...")
|
||||
run.get_datamodule().prepare_data()
|
||||
log_info("... finished.")
|
||||
prepared.add(name)
|
||||
|
||||
def plot(self, save: bool = True) -> list[Figure]:
|
||||
run = next(self.runs())
|
||||
if self._block_plotting or not run.engine.plot:
|
||||
return []
|
||||
config = run.evaluation
|
||||
set_plotstyle(config)
|
||||
figs = []
|
||||
for mode in config.checkpoints:
|
||||
df = self.dataframe_from_runs(mode)
|
||||
if config.plot.single_file:
|
||||
fig, dfs = self.plot_one_fig(df, config)
|
||||
if save:
|
||||
self.save_one_plot(fig, dfs, config, mode)
|
||||
figs.append(fig)
|
||||
else:
|
||||
# TODO: option to split into multiple files
|
||||
raise NotImplementedError("evaluation.plot.single_file=False is not implemented yet.")
|
||||
return figs
|
||||
|
||||
def plot_one_fig(self, df: DataFrame, config: EvalConfig):
|
||||
if config.column_split_key is None:
|
||||
dfs = [df]
|
||||
else:
|
||||
groups = df.groupby(config.column_split_key)
|
||||
order = some(config.column_split_order, default=map(lambda x: x[0], sorted(groups)))
|
||||
dfs: list[DataFrame] = [groups.get_group(group_name) for group_name in order]
|
||||
fig, _ = create_figure(dfs, config)
|
||||
return fig, dfs
|
||||
|
||||
def save_one_plot(self, fig, dfs: list[DataFrame], config: EvalConfig, mode: Literal["last", "best"]):
|
||||
output_file_path = get_output_file_path(dfs, config, suffix=mode)
|
||||
save_files(fig, dfs, output_file_path, config)
|
||||
|
||||
def dataframe_from_runs(self, mode: Literal["last", "best"]) -> DataFrame:
|
||||
dfs: list[DataFrame] = []
|
||||
for run in self.runs():
|
||||
df = json_normalize(run.get_config())
|
||||
if mode == "last":
|
||||
result_file = run.run_dir / run.evaluation.experiment_files.last_model
|
||||
elif mode == "best":
|
||||
result_file = run.run_dir / run.evaluation.experiment_files.best_model
|
||||
else:
|
||||
raise ValueError(f"mode {mode} not supported")
|
||||
if not result_file.is_file():
|
||||
log_warn(f"result file {result_file} not found, skipping this hyperparameter setting")
|
||||
continue
|
||||
metric = run.evaluation.plot.metric
|
||||
with open(result_file, "r", encoding="utf8") as f:
|
||||
content = json.load(f)
|
||||
if metric in content[0]:
|
||||
df.at[0, metric] = content[0][metric]
|
||||
else:
|
||||
log_warn(f"could not find value for {metric} in json, skipping this hyperparameter setting")
|
||||
continue
|
||||
dfs.append(df)
|
||||
if len(dfs) == 0:
|
||||
raise ValueError("no dataframes found, check your config")
|
||||
return concat(dfs, sort=False)
|
||||
|
||||
def _make_run(self, n: int) -> Run:
|
||||
"""
|
||||
n: number of the run, starting from 1
|
||||
setup: download and prepare data
|
||||
"""
|
||||
i = n - 1
|
||||
return Run(
|
||||
self._runs[i],
|
||||
self._defaults[i],
|
||||
self.task_key,
|
||||
self.optimizer_key,
|
||||
self.engine_key,
|
||||
self.eval_key,
|
||||
self.identifier_key
|
||||
)
|
||||
|
||||
def _named_dicts_to_list(self, searchspace: dict[str, Any], keys: list[str], valid_options: list[list[str]]):
|
||||
assert len(keys) == len(valid_options)
|
||||
for key, opts in zip(keys, valid_options):
|
||||
if key not in searchspace:
|
||||
continue
|
||||
if isinstance(searchspace[key], dict) and all(name in opts for name in searchspace[key]):
|
||||
searchspace[key] = [cfg | {self.identifier_key: name} for name, cfg in searchspace[key].items()]
|
||||
|
||||
def _fill_defaults(self):
|
||||
self._defaults = []
|
||||
for run in self._runs:
|
||||
default_cfg = {
|
||||
k: {self.identifier_key: run[k][self.identifier_key]}
|
||||
for k in [self.task_key, self.optimizer_key]
|
||||
}
|
||||
self._defaults.append(default_cfg)
|
||||
self._fill_runs_from_default(self._defaults)
|
||||
|
||||
def _fill_runs_from_default(self, runs: list[dict[str, Any]]):
|
||||
for i, _ in enumerate(runs):
|
||||
# order from higher to lower in hierarchy
|
||||
runs[i] = self._fill_named_from_default(runs[i], self.task_key, task_path)
|
||||
runs[i] = self._fill_named_from_default(runs[i], self.optimizer_key, optimizer_path)
|
||||
runs[i] = self._fill_unnamed_from_default(runs[i], lr_schedulers_path)
|
||||
runs[i] = self._fill_unnamed_from_default(runs[i], engine_path)
|
||||
runs[i] = self._fill_unnamed_from_default(runs[i], evaluation_path)
|
||||
|
||||
def _fill_unnamed_from_default(self, experiment: dict[str, Any], unnamed_root: Callable) -> dict[str, Any]:
    """Load the default.yaml below `unnamed_root` and overlay `experiment` on top of it."""
    defaults = self.parser.parse_yaml(unnamed_root() / self.default_file_name)
    self.parser.merge_dicts_hierarchical(defaults, experiment)
    return defaults
|
||||
def _fill_named_from_default(self, experiment: dict[str, Any], key: str, named_root: Callable) -> dict[str, Any]:
    """Resolve the named entry at `key` (plain string or dict with identifier), load its
    default.yaml and overlay `experiment` on top of it."""
    self._argcheck_named(experiment, key, self.identifier_key)
    entry = experiment[key]
    if isinstance(entry, dict):
        name = entry[self.identifier_key]
    else:
        # normalize the shorthand string form into the dict form
        name = entry
        experiment[key] = {self.identifier_key: name}
    defaults = self.parser.parse_yaml(named_root(name) / self.default_file_name)
    self.parser.merge_dicts_hierarchical(defaults, experiment)
    return defaults
|
||||
def _argcheck_named(self, experiment: dict[str, Any], key: str, identifier: str):
|
||||
assert key in experiment, f"You did not provide any {key}."
|
||||
assert isinstance(experiment[key], str) or identifier in experiment[key], \
|
||||
f"Unknown {key}, either specify only a string or provide a key '{identifier}'"
|
||||
32
environments/optimizer/FOB/pytorch_fob/engine/grid_search.py
Normal file
32
environments/optimizer/FOB/pytorch_fob/engine/grid_search.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from typing import Any
|
||||
|
||||
|
||||
def unique(xs: list) -> list:
    """Returns deduplicated list, preserving the original order.

    Uses membership (`in`, i.e. equality) rather than a set so that
    unhashable elements such as dicts are supported.
    """
    res = []
    for x in xs:
        if x not in res:
            res.append(x)
    return res


def grid_search(d: Any) -> list[Any]:
    """Expand a config tree into the list of all concrete configurations.

    - dict: the cartesian product over all keys is built; list values
      contribute one configuration per (deduplicated) element.
    - list: the expansions of all elements are concatenated.
    - anything else: treated as a leaf, yielding exactly one configuration.

    Note: the parameter was previously annotated `dict[str, Any]`, but the
    function recurses on list values and scalar leaves as well, so `Any` is
    the correct (and backward-compatible) annotation.
    """
    ret = []
    if isinstance(d, dict):
        if len(d) == 0:
            return [dict()]
        copy = d.copy()
        k, v = copy.popitem()
        # deduplicate so repeated list entries do not inflate the product
        configs = unique(grid_search(v))
        rest = grid_search(copy)
        for r in rest:
            for config in configs:
                ret.append(r | {k: config})
    elif isinstance(d, list):
        # concatenate the expansions of every element
        for v in d:
            configs = grid_search(v)
            for config in configs:
                ret.append(config)
    else:
        # scalar leaf: a single configuration
        ret.append(d)
    return ret
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
from pytorch_fob.engine.utils import some, log_warn
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterGroup():
|
||||
named_parameters: dict[str, Parameter]
|
||||
lr_multiplier: Optional[float] = field(default=None)
|
||||
weight_decay_multiplier: Optional[float] = field(default=None)
|
||||
optimizer_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __and__(self, other) -> "ParameterGroup":
|
||||
assert isinstance(other, ParameterGroup)
|
||||
n1 = set(self.named_parameters.keys())
|
||||
n2 = set(other.named_parameters.keys())
|
||||
all_params = self.named_parameters | other.named_parameters
|
||||
n12 = n1 & n2
|
||||
new_params = {n: all_params[n] for n in n12}
|
||||
return ParameterGroup(
|
||||
named_parameters=new_params,
|
||||
lr_multiplier=some(other.lr_multiplier, default=self.lr_multiplier),
|
||||
weight_decay_multiplier=some(other.weight_decay_multiplier, default=self.weight_decay_multiplier),
|
||||
optimizer_kwargs=self.optimizer_kwargs | other.optimizer_kwargs
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.named_parameters)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return not self.empty()
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.named_parameters) == 0
|
||||
|
||||
def to_optimizer_dict(
|
||||
self,
|
||||
lr: Optional[float] = None,
|
||||
weight_decay: Optional[float] = None
|
||||
) -> dict[str, list[Parameter] | Any]:
|
||||
names = sorted(self.named_parameters)
|
||||
d = {
|
||||
"params": [self.named_parameters[n] for n in names],
|
||||
"names": names,
|
||||
**self.optimizer_kwargs
|
||||
}
|
||||
if lr is not None:
|
||||
d["lr"] = self.lr_multiplier * lr if self.lr_multiplier is not None else lr
|
||||
if weight_decay is not None:
|
||||
d["weight_decay"] = self.weight_decay_multiplier * weight_decay \
|
||||
if self.weight_decay_multiplier is not None else weight_decay
|
||||
return d
|
||||
|
||||
|
||||
class GroupedModel(Module):
    """
    Wrapper around a nn.Module to allow specifying different optimizer settings for different parameters.
    To use this feature for your task, inherit from this class and override the `parameter_groups` method.
    Then simply wrap your model before passing it to the `__init__` method of the `TaskModel` superclass.
    """
    def __init__(self, model: Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        # delegate directly to the wrapped module
        return self.model.forward(*args, **kwargs)

    def parameter_groups(self) -> list[ParameterGroup]:
        # default grouping: weight-decay vs. no-weight-decay split
        return wd_group_named_parameters(self.model)

    def grouped_parameters(
            self,
            lr: Optional[float] = None,
            weight_decay: Optional[float] = None
    ) -> list[dict[str, list[Parameter] | Any]]:
        """Return the parameter groups rendered as optimizer dicts."""
        groups = self.parameter_groups()
        return [group.to_optimizer_dict(lr, weight_decay) for group in groups]
|
||||
|
||||
|
||||
def merge_parameter_splits(split1: list[ParameterGroup], split2: list[ParameterGroup]) -> list[ParameterGroup]:
    """
    Merge two lists of ParameterGroup objects into a single list.
    Assumes that both input lists partition the parameters.
    """
    intersections = [a & b for a in split1 for b in split2]
    return [group for group in intersections if not group.empty()]
|
||||
|
||||
|
||||
def group_named_parameters(
        model: Module,
        g1_conds: Iterable[Callable] = (lambda *_: True,),
        g2_conds: Iterable[Callable] = (lambda *_: True,),
        special_conds: Iterable[Callable] = tuple(),
        ignore_conds: Iterable[Callable] = tuple(),
        g1_kwargs: Optional[dict[str, Any]] = None,
        g2_kwargs: Optional[dict[str, Any]] = None,
        debug: bool = False
) -> list[ParameterGroup]:
    """
    Group named parameters based on specified conditions and return a list of ParameterGroup objects.

    Each condition is a callable `(module, parameter, full_param_name) -> bool`.
    Rule precedence per parameter: ignore > special > g1 > g2; any trainable
    parameter that matches no rule falls back into group 1 (see below).

    Args:
        model (Module): The neural network model.
        g1_conds (Iterable[Callable]): Conditions for selecting parameters for group 1.
        g2_conds (Iterable[Callable]): Conditions for selecting parameters for group 2.
        special_conds (Iterable[Callable]): Conditions for selecting special parameters that should not be grouped.
        ignore_conds (Iterable[Callable]): Conditions for ignoring parameters (e.g. if they occur in submodules).
        g1_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for constructor of group 1.
        g2_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for constructor of group 2.
        debug (bool): Warn about parameters that match no rule.

    Returns:
        List[ParameterGroup]: A list of ParameterGroup objects containing named parameters.
    """
    g1_kwargs = g1_kwargs if g1_kwargs is not None else {}
    g2_kwargs = g2_kwargs if g2_kwargs is not None else {}
    s1 = set()       # names destined for group 1
    s2 = set()       # names destined for group 2
    special = set()  # names excluded from grouping entirely
    # only trainable parameters participate in grouping
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    for mn, m in model.named_modules():
        # NOTE(review): m.named_parameters() is recursive, so a parameter can be visited
        # once per ancestor module; the set semantics below make repeats harmless, but the
        # *first matching rule over any visit* decides — confirm this is intended.
        for pn, p in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn  # full param name
            if not p.requires_grad or fpn not in param_dict:
                continue  # frozen weights
            elif any(c(m, p, fpn) for c in ignore_conds):
                continue
            elif any(c(m, p, fpn) for c in special_conds):
                special.add(fpn)
            elif any(c(m, p, fpn) for c in g1_conds):
                s1.add(fpn)
            elif any(c(m, p, fpn) for c in g2_conds):
                s2.add(fpn)
            elif debug:
                # NOTE(review): log_warn is called with several positional args here —
                # confirm it accepts *args (otherwise this raises in debug mode)
                log_warn("group_named_parameters: Not using any rule for ", fpn, " in ", type(m))

    # every trainable parameter not claimed by group 2 or 'special' defaults to group 1
    s1 |= (param_dict.keys() - s2 - special)

    # validate that we considered every parameter
    inter_params = s1 & s2
    union_params = s1 | s2
    assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both s1/s2 sets!"
    assert len(
        param_dict.keys() - special - union_params) == 0, \
        f"parameters {str(param_dict.keys() - union_params)} \
        were not separated into either s1/s2 set!"

    if not s2:
        # no split necessary: return a single group without the g1/g2 kwargs
        param_groups = [ParameterGroup(
            named_parameters=dict(zip(sorted(union_params), (param_dict[pn] for pn in sorted(union_params))))
        )]
    else:
        param_groups = [
            ParameterGroup(
                named_parameters=dict(zip(sorted(s1), (param_dict[pn] for pn in sorted(s1)))),
                **g1_kwargs
            ),
            ParameterGroup(
                named_parameters=dict(zip(sorted(s2), (param_dict[pn] for pn in sorted(s2)))),
                **g2_kwargs
            ),
        ]

    return param_groups
|
||||
|
||||
|
||||
def wd_group_named_parameters(model: Module) -> list[ParameterGroup]:
    """Split `model`'s parameters into a weight-decay group and a no-weight-decay group.

    Weights of linear/conv layers receive decay; biases, norm layers and embeddings do not.
    Parameters carrying an `_optim` attribute are kept as their own special groups.
    """
    decay_modules = (nn.Linear, nn.modules.conv._ConvNd)  # pylint: disable=protected-access # noqa
    no_decay_modules = (nn.modules.batchnorm._NormBase,  # pylint: disable=protected-access # noqa
                        nn.GroupNorm, nn.LayerNorm,
                        nn.LocalResponseNorm,
                        nn.Embedding)
    skipped_modules = (nn.Sequential,)
    return group_named_parameters(
        model,
        g1_conds=[lambda m, _, pn: pn.endswith('weight') and isinstance(m, decay_modules)],
        g2_conds=[lambda m, _, pn: pn.endswith('bias') or isinstance(m, no_decay_modules)],
        special_conds=[lambda m, p, pn: hasattr(p, '_optim')],
        ignore_conds=[lambda m, p, pn: isinstance(m, skipped_modules)],
        g2_kwargs={'weight_decay_multiplier': 0.0}
    )
|
||||
|
||||
|
||||
def resolve_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Split two optimizer parameter dicts into [only-in-1, only-in-2, in-both].
    On the shared part, `dict2`'s params and extra kwargs take precedence.
    """
    names1, names2 = set(dict1["names"]), set(dict2["names"])
    lookup1 = dict(zip(dict1["names"], dict1["params"]))
    lookup2 = dict(zip(dict2["names"], dict2["params"]))
    assert len(names1) == len(dict1["params"])
    assert len(names2) == len(dict2["params"])
    extras1 = {k: v for k, v in dict1.items() if k not in ["params", "names"]}
    extras2 = {k: v for k, v in dict2.items() if k not in ["params", "names"]}
    shared = names1 & names2
    only1 = names1 - names2
    only2 = names2 - names1
    assert shared | only1 | only2 == names1 | names2

    def _build(names, lookup, extras) -> dict[str, Any]:
        # one-line helper: assemble a parameter dict with sorted, deterministic ordering
        ordered = sorted(names)
        return {"params": [lookup[n] for n in ordered], "names": ordered, **extras}

    return [
        _build(only1, lookup1, extras1),
        _build(only2, lookup2, extras2),
        # dict2 takes precedence if a param or kwarg is present in both dicts
        _build(shared, {**lookup1, **lookup2}, {**extras1, **extras2}),
    ]
|
||||
|
||||
|
||||
def intersect_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> Optional[dict[str, Any]]:
    """Return the shared part of two parameter dicts, or None if they share no parameters."""
    shared = resolve_parameter_dicts(dict1, dict2)[2]
    if len(shared["params"]) > 0:
        return shared
    return None
|
||||
|
||||
|
||||
def merge_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> list[dict[str, Any]]:
    """Split two parameter dicts into disjoint groups, dropping empty ones."""
    resolved = resolve_parameter_dicts(dict1, dict2)
    return [entry for entry in resolved if len(entry["params"]) > 0]
|
||||
65
environments/optimizer/FOB/pytorch_fob/engine/parser.py
Normal file
65
environments/optimizer/FOB/pytorch_fob/engine/parser.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional
|
||||
import re
|
||||
import yaml
|
||||
|
||||
|
||||
class YAMLParser():
    """Small helper around PyYAML for loading configs and applying CLI-style overrides."""

    def __init__(self) -> None:
        pass

    def parse_yaml(self, file: Path) -> Any:
        """
        Opens and parses a YAML file.
        """
        with open(file, "r", encoding="utf8") as f:
            return yaml.safe_load(f)

    def parse_yamls_and_extra_args(self,
                                   default_yaml: Path,
                                   custom_yaml: Optional[Path],
                                   additional_args: Iterable[str] = tuple()
                                   ) -> dict:
        """Load the default yaml, overlay the custom yaml (if any), then apply CLI overrides.

        assumes that there is a dict in the yaml
        """
        config_to_use = self.parse_yaml(default_yaml)
        if custom_yaml is not None:
            user_yaml = self.parse_yaml(custom_yaml)
            # merge in place
            self.merge_dicts_hierarchical(lo=config_to_use, hi=user_yaml)
        self.parse_args_into_searchspace(config_to_use, additional_args)
        return config_to_use

    def parse_args_into_searchspace(self, searchspace: dict[str, Any], args: Iterable[str]):
        """
        Overwrites args given in the form of 'this.that=something'. Also supports lists: 'this.that[0]=something'
        """
        for arg in args:
            self._parse_arg_into_searchspace(searchspace, arg)

    def _parse_arg_into_searchspace(self, searchspace: dict[str, Any], arg: str):
        """Apply a single 'a.b[i].c=value' override to `searchspace` in place."""
        # split on the FIRST '=' only, so values may themselves contain '='
        # (bugfix: a plain split("=") raised ValueError for e.g. 'key=a=b')
        keys, value = arg.split("=", 1)
        keys = keys.split(".")
        keys_with_list_indices = []
        for key in keys:
            # 'name[3]' / 'name[-1]' -> key 'name' followed by an int index
            match = re.search(r"^(.*?)\[(\-?\d+)\]$", key)
            if match:
                keys_with_list_indices.append(match.group(1))
                keys_with_list_indices.append(int(match.group(2)))
            else:
                keys_with_list_indices.append(key)
        target = searchspace
        for key in keys_with_list_indices[:-1]:
            # create missing intermediate dicts; list indices must already exist
            if isinstance(target, dict) and key not in target:
                target[key] = {}
            target = target[key]
        # parse the value with YAML so numbers / booleans / lists come out typed
        target[keys_with_list_indices[-1]] = yaml.safe_load(value)

    def merge_dicts_hierarchical(self, lo: dict, hi: dict):
        """
        Overwrites values in `lo` with values from `hi` if they are present in both.
        Nested dicts are merged recursively; all other values are replaced.
        """
        for k, v in hi.items():
            if isinstance(v, dict) and isinstance(lo.get(k, None), dict):
                self.merge_dicts_hierarchical(lo[k], v)
            else:
                lo[k] = v
|
||||
298
environments/optimizer/FOB/pytorch_fob/engine/run.py
Normal file
298
environments/optimizer/FOB/pytorch_fob/engine/run.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
import hashlib
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from lightning import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
|
||||
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
|
||||
from lightning.pytorch.loggers import Logger, TensorBoardLogger, CSVLogger
|
||||
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
|
||||
import torch
|
||||
import yaml
|
||||
from pytorch_fob.engine.callbacks import LogTrainingStats, OptimizerTime, PrintEpochWithTime, RestrictTrainEpochs
|
||||
from pytorch_fob.engine.configs import EngineConfig, EvalConfig, OptimizerConfig, TaskConfig
|
||||
from pytorch_fob.engine.utils import AttributeDict, EndlessList, calculate_steps, concatenate_dict_keys, convert_type_inside_dict, dict_differences, findfirst, path_to_str_inside_dict, precision_with_fallback, seconds_to_str, trainer_strategy, write_results, log_warn, log_info
|
||||
from pytorch_fob.optimizers.optimizers import Optimizer
|
||||
from pytorch_fob.tasks.tasks import TaskDataModule, TaskModel, import_task
|
||||
|
||||
|
||||
class Run():
    """
    One concrete training/evaluation run: a fully resolved combination of task,
    optimizer, engine and evaluation settings, bound to an output directory.
    """
    def __init__(
            self,
            config: dict[str, Any],
            default_config: dict[str, Any],
            task_key: str,
            optimizer_key: str,
            engine_key: str,
            eval_key: str,
            identifier_key: str
    ) -> None:
        """
        setup: download and prepare data before creating the Run
        """
        # full resolved config plus the defaults it deviates from (the diff names the output dir)
        self._config = config
        self._default_config = default_config
        self.task_key = task_key
        self.optimizer_key = optimizer_key
        self.engine_key = engine_key
        self.eval_key = eval_key
        self.identifier_key = identifier_key
        self._generate_configs()
        self._set_outpath()
        # callbacks are created lazily; cached so they can be queried after training
        self._callbacks = AttributeDict({})

    def start(self) -> dict[str, _EVALUATE_OUTPUT]:
        """Execute the run (train / validate / test as configured) and write scores.json."""
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.export_config()
        scores: dict[str, _EVALUATE_OUTPUT] = {}
        if any([self.engine.train, self.engine.test]):
            self._ensure_resume_path()
            self.ensure_max_steps()
            torch.set_float32_matmul_precision('high')
            seed_everything(self.engine.seed, workers=True)
            model, data_module = self.get_task()
            if self.engine.train:
                trainer = self.get_trainer()
                self._train(trainer, model, data_module)
                scores["mean_optimizer_time_ms"] = self._callbacks["optimizer_time"].total_mean_optimizer_step_time_ms
                if self.engine.validate:
                    scores["validation"] = self._validate(trainer, model, data_module)
            if self.engine.test:
                tester = self.get_tester()
                if self.engine.train:  # no need to load last checkpoint, model is already loaded
                    ckpt = None
                elif self.engine.resume is not None:
                    ckpt = self.engine.resume
                else:
                    log_warn(
                        "No last checkpoint found, evaluating untrained model. " + \
                        "If this is unexpected, try to set 'engine.resume=true'."
                    )
                    ckpt = None
                scores["test_final"] = self._test(tester, model, data_module, ckpt=ckpt)  # type: ignore (see ensure_resume_path)
                best_path = self.get_best_checkpoint()
                if best_path is not None:
                    scores["test_best"] = self._test(tester, model, data_module, Path(best_path))
                else:
                    log_info("No best checkpoint found, skipping test.")
        write_results(scores, self.run_dir / "scores.json")
        return scores

    def _train(self, trainer: Trainer, model: LightningModule, data_module: LightningDataModule):
        """Fit the model, time the whole training, and write train_time.txt."""
        start_time = time.time()
        if self.engine.accelerator == "gpu" and torch.cuda.is_available():
            # restrict scaled-dot-product attention backends: the mem-efficient kernel is
            # only enabled when memory optimization is requested or determinism is not required
            with torch.backends.cuda.sdp_kernel(
                enable_flash=True,
                enable_math=True,
                enable_mem_efficient=(self.engine.optimize_memory or not self.engine.deterministic)
            ):
                trainer.fit(model, datamodule=data_module, ckpt_path=self.engine.resume)  # type: ignore
        else:
            trainer.fit(model, datamodule=data_module, ckpt_path=self.engine.resume)  # type: ignore
        end_time = time.time()
        train_time = int(end_time - start_time)
        log_info(f"Finished training in {seconds_to_str(train_time)}.")

        # Write train_time.txt
        train_time_path = self.run_dir / "train_time.txt"
        with open(train_time_path, "w") as f:
            f.write(str(train_time) + "\n")

    def _validate(self, trainer: Trainer, model: LightningModule, data_module: LightningDataModule) -> _EVALUATE_OUTPUT:
        """Run the validation loop and return its metrics."""
        score = trainer.validate(model, datamodule=data_module)
        return score

    def _test(self, tester: Trainer, model: LightningModule, data_module: LightningDataModule, ckpt: Optional[Path] = None) -> _EVALUATE_OUTPUT:
        """Test the given checkpoint (default: the resume checkpoint) and write results_<mode>_model.json."""
        ckpt_path = self.engine.resume if ckpt is None else ckpt
        # a 'last*' checkpoint (or none at all) means we are testing the final model
        mode = "final" if ckpt_path is None or ckpt_path.stem.startswith("last") else "best"  # type: ignore
        log_info(f"Testing {mode} checkpoint...")
        score = tester.test(model, datamodule=data_module, ckpt_path=ckpt_path)  # type: ignore
        write_results(score, self.run_dir / f"results_{mode}_model.json")
        return score

    def export_config(self):
        """Dump the resolved config as YAML into the run directory."""
        with open(self.run_dir / "config.yaml", "w", encoding="utf8") as f:
            d = path_to_str_inside_dict(self._config)
            d = convert_type_inside_dict(d, EndlessList, list)
            yaml.safe_dump(d, f)

    def export_config_dict(self) -> dict[str, Any]:
        """Return the resolved config as a plain YAML-serializable dict."""
        d = path_to_str_inside_dict(self._config)
        d = convert_type_inside_dict(d, EndlessList, list)
        return d

    def get_config(self) -> AttributeDict:
        """Attribute-style view of the raw resolved config."""
        return AttributeDict(self._config)

    def get_optimizer(self) -> Optimizer:
        """Optimizer wrapper for this run's optimizer config."""
        return Optimizer(self.optimizer)

    def get_task(self) -> tuple[TaskModel, TaskDataModule]:
        """Instantiate the task's model and datamodule for this run's optimizer."""
        task_module = import_task(self.task.name)
        return task_module.get_task(self.get_optimizer(), self.task)

    def get_datamodule(self) -> TaskDataModule:
        """Instantiate only the task's datamodule (used e.g. for step calculation)."""
        task_module = import_task(self.task.name)
        return task_module.get_datamodule(self.task)

    def get_callbacks(self) -> list[Callback]:
        # lazy init: checkpoint callbacks are queried later via self._callbacks
        if len(self._callbacks) < 1:
            self._init_callbacks()
        return list(self._callbacks.values())

    def get_loggers(self) -> list[Logger]:
        """TensorBoard and CSV loggers, both writing under the run directory."""
        return [
            TensorBoardLogger(
                save_dir=self.run_dir,
                name="tb_logs"
            ),
            CSVLogger(
                save_dir=self.run_dir,
                name="csv_logs"
            )
        ]

    def get_trainer(self) -> Trainer:
        """Build the Lightning Trainer used for training."""
        return Trainer(
            max_steps=self.engine.max_steps,
            logger=self.get_loggers(),
            callbacks=self.get_callbacks(),
            devices=self.engine.devices,
            strategy=trainer_strategy(self.engine.devices),
            enable_progress_bar=(not self.engine.silent),
            deterministic=self.engine.deterministic,
            detect_anomaly=self.engine.detect_anomaly,
            gradient_clip_val=self.engine.gradient_clip_val,
            gradient_clip_algorithm=self.engine.gradient_clip_alg,
            precision=precision_with_fallback(self.engine.precision),  # type: ignore
            accelerator=self.engine.accelerator,
            log_every_n_steps=self.engine.logging_inteval  # NOTE(review): 'logging_inteval' (sic) is the engine config key name
        )

    def get_tester(self) -> Trainer:
        """Build a single-device, logger-less Trainer used only for testing."""
        return Trainer(
            devices=1,
            logger=False,
            enable_progress_bar=(not self.engine.silent),
            deterministic=self.engine.deterministic,
            precision=precision_with_fallback(self.engine.precision),  # type: ignore
            accelerator=self.engine.accelerator
        )

    def get_best_checkpoint(self) -> Optional[Path]:
        """Best checkpoint path from the callback if available, else found by 'best*' filename."""
        model_checkpoint = self._callbacks.get("best_model_checkpoint", None)
        if model_checkpoint is not None:
            model_checkpoint = Path(model_checkpoint.best_model_path)
            # an empty best_model_path resolves to a directory -> treat as "no best checkpoint"
            model_checkpoint = model_checkpoint if not model_checkpoint.is_dir() else None
        if model_checkpoint is None:
            available_checkpoints = self.get_available_checkpoints()
            model_checkpoint = findfirst(lambda x: x.stem.startswith("best"), available_checkpoints)
        return model_checkpoint

    def get_available_checkpoints(self) -> list[Path]:
        """List all '*.ckpt' files in the checkpoint directory (empty list if it does not exist)."""
        if self.checkpoint_dir.exists():
            return list(filter(lambda x: x.suffix == ".ckpt", self.checkpoint_dir.iterdir()))
        return []

    def ensure_max_steps(self):
        """
        Ensures that `self.task.max_steps` is calculated and set correctly.
        """
        if self.task.max_steps is None:
            max_steps = self._calc_max_steps()
            self._config[self.task_key]["max_steps"] = max_steps
            # also patch the defaults so the derived value does not show up in the outpath diff
            if self._default_config[self.task_key]["max_steps"] is None:
                self._default_config[self.task_key]["max_steps"] = max_steps
            self._generate_configs()
            log_info(f"'max_steps' not set explicitly, using {max_steps=} (calculated from " +
                     f"max_epochs={self.task.max_epochs}, batch_size={self.task.batch_size}, devices={self.engine.devices})")

    def _ensure_resume_path(self):
        """
        Ensures that `self.engine.resume` is either a valid Path or None.
        """
        if isinstance(self.engine.resume, Path):
            pass
        elif isinstance(self.engine.resume, bool):
            resume_path = None
            if self.engine.resume:
                available_checkpoints = self.get_available_checkpoints()
                if len(available_checkpoints) < 1:
                    log_warn("engine.resume=True but no checkpoint was found. Starting run from scratch.")
                else:
                    # resume from the checkpoint named 'last'
                    resume_path = findfirst(lambda x: x.stem == "last", available_checkpoints)
            self._config[self.engine_key]["resume"] = resume_path
            self._generate_configs()
        else:
            raise TypeError(f"Unsupportet type for 'resume', got {type(self.engine.resume)=}.")

    def _calc_max_steps(self) -> int:
        """Derive max_steps from max_epochs, dataset size, devices and batch size."""
        dm = self.get_datamodule()
        dm.setup("fit")
        train_samples = len(dm.data_train)
        return calculate_steps(self.task.max_epochs, train_samples, self.engine.devices, self.task.batch_size)

    def _init_callbacks(self):
        """Create and cache all Lightning callbacks for this run."""
        self._callbacks["optimizer_time"] = OptimizerTime()
        # tracks the checkpoint with the best target metric
        self._callbacks["best_model_checkpoint"] = ModelCheckpoint(
            dirpath=self.checkpoint_dir,
            filename="best-{epoch}-{step}",
            monitor=self.task.target_metric,
            mode=self.task.target_metric_mode
        )
        # keeps the 'last' checkpoint for resuming
        self._callbacks["model_checkpoint"] = ModelCheckpoint(
            dirpath=self.checkpoint_dir,
            enable_version_counter=False,
            every_n_epochs=1,
            save_last=True
        )
        if self.engine.early_stopping is not None:
            self._callbacks["early_stopping"] = EarlyStopping(
                monitor=self.engine.early_stopping_metric,
                mode=self.task.target_metric_mode,
                patience=self.engine.early_stopping,
                check_finite=self.engine.check_finite,
                log_rank_zero_only=True
            )
        self._callbacks["lr_monitor"] = LearningRateMonitor(
            logging_interval=self.optimizer.lr_interval
        )
        if self.engine.log_extra:
            # log_extra may be a bool (defaults) or a dict of LogTrainingStats kwargs
            self._callbacks["extra"] = LogTrainingStats(
                log_every_n_steps=self.engine.logging_inteval,
                **(self.engine.log_extra if isinstance(self.engine.log_extra, dict) else {})
            )
        self._callbacks["print_epoch"] = PrintEpochWithTime(self.engine.silent)
        if self.engine.restrict_train_epochs is not None:
            self._callbacks["restrict_train_epochs"] = RestrictTrainEpochs(self.engine.restrict_train_epochs)
        # TODO: callback for logging time per step

    def outpath_exclude_keys(self) -> list[str]:
        """Config keys that must not influence the output directory name."""
        return [
            self.eval_key,
            "output_dir_name"
        ]

    def _set_outpath(self):
        """Derive run_dir from the diff between this config and the defaults; hash overly long names."""
        base: Path = self.engine.output_dir / self.task.output_dir_name / self.optimizer.output_dir_name
        exclude_keys = self.outpath_exclude_keys()
        exclude_keys += self.engine.outpath_irrelevant_engine_keys()
        diffs = concatenate_dict_keys(dict_differences(self._config, self._default_config), exclude_keys=exclude_keys)
        run_dir = ",".join(f"{k}={str(v)}" for k, v in sorted(diffs.items())) if diffs else "default"
        if len(run_dir) > 254:  # max file name length
            hashdir = hashlib.md5(run_dir.encode()).hexdigest()
            log_info(f"folder name {run_dir} is too long, using {hashdir} instead.")
            run_dir = hashdir
        self.run_dir = base / run_dir
        self.checkpoint_dir = self.run_dir / "checkpoints"

    def _generate_configs(self):
        """(Re)build the typed config views from the raw config dict; called after any config mutation."""
        self.engine = EngineConfig(self._config, self.task_key, self.engine_key)
        self.optimizer = OptimizerConfig(self._config, self.optimizer_key, self.task_key, self.identifier_key)
        self.task = TaskConfig(self._config, self.task_key, self.engine_key, self.identifier_key)
        self.evaluation = EvalConfig(
            self._config,
            eval_key=self.eval_key,
            engine_key=self.engine_key,
            ignore_keys=self.engine.outpath_irrelevant_engine_keys(prefix=f"{self.engine_key}.") + [f"{self.optimizer_key}.output_dir_name", f"{self.task_key}.output_dir_name"]
        )
|
||||
171
environments/optimizer/FOB/pytorch_fob/engine/run_schedulers.py
Normal file
171
environments/optimizer/FOB/pytorch_fob/engine/run_schedulers.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Iterable, Optional, Sequence
|
||||
import traceback
|
||||
import yaml
|
||||
from pytorch_fob.engine.run import Run
|
||||
from pytorch_fob.engine.slurm import Slurm
|
||||
from pytorch_fob.engine.utils import log_info, log_warn, seconds_to_str, some, str_to_seconds
|
||||
|
||||
|
||||
# module entry points invoked via `python -m ...` inside the generated sbatch scripts
FOB_RUN_SCRIPT = "pytorch_fob.run_experiment"
FOB_EVAL_SCRIPT = "pytorch_fob.evaluate_experiment"
|
||||
|
||||
|
||||
def argcheck_allequal_engine(
        runs: list[Run],
        keys: list[str],
        reason: str = "'engine.run_scheduler=slurm_array'"
) -> None:
    """Raise ValueError unless every run agrees with the first one on all given engine keys."""
    reference = runs[0]
    for key in keys:
        mismatched = any(run.engine[key] != reference.engine[key] for run in runs[1:])
        if mismatched:
            req = ", ".join("engine." + key_name for key_name in keys)
            raise ValueError(f"All runs must have the same values for {req} when using {reason}")
|
||||
|
||||
|
||||
def export_experiment(run: Run, experiment: dict[str, Any]) -> Path:
    """Write the experiment dict as 'experiment.yaml' into the run directory and return its path."""
    run.run_dir.mkdir(parents=True, exist_ok=True)
    target = run.run_dir / "experiment.yaml"
    with open(target, "w", encoding="utf8") as f:
        yaml.safe_dump(experiment, f)
    return target
|
||||
|
||||
|
||||
def process_args(args: dict[str, str], run: Run) -> None:
    """Fill in sbatch defaults (scaled wall time, GPUs, tasks, cpus) derived from the run's engine config.

    Mutates `args` in place; user-provided entries are never overwritten.
    """
    if "time" in args:
        time = args["time"]
        seconds = str_to_seconds(time) if isinstance(time, str) else time
        # scale the requested wall time by the user-configured safety factor
        args["time"] = seconds_to_str(int(run.engine.sbatch_time_factor * seconds))
    if "gres" not in args and "gpus" not in args:
        args["gres"] = f"gpu:{run.engine.devices}"
    if all(not k.startswith("ntasks") for k in args):
        args["ntasks-per-node"] = str(run.engine.devices)
    if all(not k.startswith("cpus") for k in args):
        args["cpus-per-task"] = str(run.engine.workers)
|
||||
|
||||
|
||||
def wrap_template(template_path: Optional[Path], command: str, placeholder: str = "__FOB_COMMAND__") -> str:
    """Embed `command` into the sbatch template at `placeholder`, or append it if the
    placeholder is absent. Without a template, `command` is returned unchanged."""
    if template_path is None:
        return command
    template = template_path.read_text(encoding="utf8")
    if placeholder in template:
        return template.replace(placeholder, command)
    return f"{template}\n{command}\n"
|
||||
|
||||
|
||||
def get_command(experiment_file: Path, index: Optional[str], plot: bool) -> str:
    """Build the srun command line for either a training run or the plotting/eval job."""
    if plot:
        run_script = FOB_EVAL_SCRIPT
        disable_plot = ""
    else:
        run_script = FOB_RUN_SCRIPT
        disable_plot = "engine.plot=false"
    scheduler = f"engine.run_scheduler=single:{index}" if index is not None else ""
    return f"srun python -m {run_script} {experiment_file} {scheduler} {disable_plot}"
|
||||
|
||||
|
||||
def get_job_name(run: Run) -> str:
    """Slurm job name in the form FOB-<task>-<optimizer>."""
    task_name = run.task.name
    optimizer_name = run.optimizer.name
    return f"FOB-{task_name}-{optimizer_name}"
|
||||
|
||||
|
||||
def get_slurm(job_name: str, args: dict[str, str], log_dir: Path, scripts_dir: Path) -> Slurm:
    """Construct a Slurm wrapper with resolved absolute log and script directories."""
    return Slurm(
        job_name,
        args,
        log_dir=str(log_dir.resolve()),
        scripts_dir=str(scripts_dir.resolve()),
        # TODO: maybe add arg or just remove 'nounset'
        bash_strict=False
    )
|
||||
|
||||
|
||||
def run_slurm(
        job_name: str,
        command: str,
        args: dict[str, str],
        log_dir: Path,
        save_sbatch_scripts: Optional[Path] = None,
        dependencies: Sequence[int] = tuple(),
        dependency_type: str = "afterok"
) -> Optional[int]:
    """Submit `command` via sbatch; the generated script is kept only if
    `save_sbatch_scripts` is given, otherwise it lives in a throwaway temp dir."""
    if save_sbatch_scripts is not None:
        slurm = get_slurm(job_name, args, log_dir, scripts_dir=save_sbatch_scripts)
        return slurm.run(command, name_addition="", depends_on=dependencies, dependency_type=dependency_type)
    with TemporaryDirectory() as tmpdir:
        slurm = get_slurm(job_name, args, log_dir, scripts_dir=Path(tmpdir).resolve())
        return slurm.run(command, name_addition="", depends_on=dependencies, dependency_type=dependency_type)
|
||||
|
||||
|
||||
def run_plotting_job(
        experiment_file: Path,
        args: dict[str, str],
        log_dir: Path,
        dependencies: Sequence[int],
        template: Optional[Path] = None
) -> None:
    """Submit a lightweight CPU-only slurm job that plots results.

    The job runs once all *dependencies* have finished regardless of their
    exit status ('afterany'), on a single node with two CPUs and no GPUs.

    :param experiment_file: exported experiment yaml passed to the eval script
    :param args: sbatch arguments of the training job, adapted for plotting
    :param log_dir: directory for slurm log files
    :param dependencies: job ids the plotting job waits for
    :param template: optional sbatch script template to wrap the command in
    """
    # Work on a copy so the caller's sbatch argument dict is not mutated.
    args = dict(args)
    args["time"] = seconds_to_str(300)  # 5 minutes should be plenty of time to plot
    args.pop("array", None)
    # no gpus needed for plotting
    args.pop("gpus", None)
    args.pop("gres", None)
    # just one cpu per node for plotting
    remove_keys = [k for k in args.keys() if k.startswith("ntasks") or k.startswith("cpus")]
    for k in remove_keys:
        args.pop(k)
    args["nodes"] = "1"
    args["ntasks-per-node"] = "1"
    args["cpus-per-task"] = "2"
    command = get_command(experiment_file, None, plot=True)
    command = wrap_template(template, command)
    run_slurm("FOB-plot", command, args, log_dir, dependencies=dependencies, dependency_type="afterany")
|
||||
|
||||
|
||||
def slurm_array(runs: list[Run], experiment: dict[str, Any]) -> None:
    """Submit all *runs* as one slurm job array, optionally followed by a plotting job.

    Requires the listed engine options to be identical across all runs.
    """
    equal_req = ["devices", "workers", "sbatch_args", "slurm_log_dir", "sbatch_script_template", "run_scheduler"]
    argcheck_allequal_engine(runs, equal_req)
    first = runs[0]  # all runs have the same args
    sbatch_args = first.engine.sbatch_args
    log_dir = some(first.engine.slurm_log_dir, default=first.engine.output_dir / "slurm_logs")
    if "array" not in sbatch_args:
        sbatch_args["array"] = f"1-{len(runs)}"
    process_args(sbatch_args, first)
    # export the experiment file for every run; the first path is used for the array command
    exported_files = [export_experiment(r, experiment).resolve() for r in runs]
    experiment_file = exported_files[0]
    command = wrap_template(
        first.engine.sbatch_script_template,
        get_command(experiment_file, "$SLURM_ARRAY_TASK_ID", plot=False),
    )
    job_id = run_slurm(
        get_job_name(first),
        command,
        sbatch_args,
        log_dir,
        save_sbatch_scripts=first.engine.save_sbatch_scripts,
    )
    if job_id is not None and first.engine.plot:
        run_plotting_job(experiment_file, sbatch_args, log_dir, [job_id], template=first.engine.sbatch_script_template)
|
||||
|
||||
|
||||
def slurm_jobs(runs: list[Run], experiment: dict[str, Any]) -> list[int]:
    """Submit each run as its own slurm job and return the collected job ids.

    If any run requests plotting, a single plotting job depending on all
    submitted jobs is queued afterwards.
    """
    job_ids: list[int] = []
    experiment_file = Path()
    for run_index, run in enumerate(runs, start=1):
        sbatch_args = run.engine.sbatch_args
        process_args(sbatch_args, run)
        log_dir = some(run.engine.slurm_log_dir, default=run.run_dir / "slurm_logs")
        experiment_file = export_experiment(run, experiment).resolve()
        command = wrap_template(
            run.engine.sbatch_script_template,
            get_command(experiment_file, str(run_index), plot=False),
        )
        job_id = run_slurm(
            get_job_name(run),
            command,
            sbatch_args,
            log_dir,
            save_sbatch_scripts=run.engine.save_sbatch_scripts,
        )
        if job_id is not None:
            job_ids.append(job_id)
    if job_ids and any(r.engine.plot for r in runs):
        equal_req = ["slurm_log_dir", "sbatch_script_template"]
        argcheck_allequal_engine(runs, equal_req, reason="'engine.plot=true' with 'engine.run_scheduler=slurm_jobs'")
        # sbatch args and log dir of the last submitted run are reused for plotting
        run_plotting_job(experiment_file, sbatch_args, log_dir, job_ids, template=runs[0].engine.sbatch_script_template)
    return job_ids
|
||||
|
||||
|
||||
def sequential(runs: Iterable[Run], n_runs: int, experiment: dict[str, Any]):
    """Execute *runs* one after another in the current process.

    A RuntimeError raised by a single run (e.g. by detect_anomaly) is
    logged and does not abort the remaining runs.
    """
    for run_number, run in enumerate(runs, start=1):
        log_info(f"Starting run {run_number}/{n_runs}.")
        export_experiment(run, experiment)
        try:
            run.start()
        except RuntimeError:  # detect_anomaly raises RuntimeError
            trace = traceback.format_exc()
            log_warn(f"Run {run_number}/{n_runs} failed with {trace}.")
|
||||
181
environments/optimizer/FOB/pytorch_fob/engine/slurm.py
Normal file
181
environments/optimizer/FOB/pytorch_fob/engine/slurm.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Brent Pedersen - Bioinformatics
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Adapted from https://github.com/brentp/slurmpy
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import atexit
|
||||
import hashlib
|
||||
import datetime
|
||||
from typing import Optional, Sequence
|
||||
|
||||
TMPL = """\
|
||||
#!/bin/bash
|
||||
|
||||
#SBATCH -e {log_dir}/{name}.%J.err
|
||||
#SBATCH -o {log_dir}/{name}.%J.out
|
||||
#SBATCH -J {name}
|
||||
|
||||
{header}
|
||||
|
||||
{bash_setup}
|
||||
|
||||
__script__"""
|
||||
|
||||
|
||||
def tmp(suffix=".sh"):
    """Create a temporary file and return its path; the file is deleted at exit.

    Uses ``tempfile.mkstemp`` instead of the deprecated, race-prone
    ``tempfile.mktemp``: the file is created atomically and its descriptor
    closed immediately, so callers can simply reopen the path for writing.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    atexit.register(os.unlink, path)
    return path
|
||||
|
||||
|
||||
class Slurm(object):
    """Builds sbatch scripts from a template and submits them (adapted from slurmpy)."""

    def __init__(self, name, slurm_kwargs=None, tmpl=None,
                 date_in_name=True, scripts_dir="slurm-scripts",
                 log_dir='logs', bash_strict=True):
        # name: base job name (sanitized to alphanumerics and dashes)
        # slurm_kwargs: sbatch options rendered as #SBATCH header lines
        # tmpl: script template; defaults to module-level TMPL
        # scripts_dir: where generated scripts are stored (None -> tempfile)
        # bash_strict: prepend strict-mode bash setup to the script
        if slurm_kwargs is None:
            slurm_kwargs = {}
        if tmpl is None:
            tmpl = TMPL
        self.log_dir = log_dir
        self.bash_strict = bash_strict

        header = []
        # default walltime if the caller did not specify one
        if 'time' not in slurm_kwargs.keys():
            slurm_kwargs['time'] = '84:00:00'
        for k, v in slurm_kwargs.items():
            # long options use '--key=value', single-letter options '-k value'
            if len(k) > 1:
                k = "--" + k + "="
            else:
                k = "-" + k + " "
            header.append(f"#SBATCH {k}{v}")

        # add bash setup list to collect bash script config
        bash_setup = []
        if bash_strict:
            bash_setup.append("set -eo pipefail -o nounset")

        self.header = "\n".join(header)
        self.bash_setup = "\n".join(bash_setup)
        # keep only alphanumerics and dashes in the job name
        self.name = "".join(x for x in name.replace(
            " ", "-") if x.isalnum() or x == "-")
        self.tmpl = tmpl
        self.slurm_kwargs = slurm_kwargs
        if scripts_dir is not None:
            self.scripts_dir = os.path.abspath(scripts_dir)
        else:
            self.scripts_dir = None
        self.date_in_name = bool(date_in_name)

    def __str__(self):
        # the rendered sbatch script (without the __script__ substitution)
        return self.tmpl.format(name=self.name, header=self.header,
                                log_dir=self.log_dir,
                                bash_setup=self.bash_setup)

    def _tmpfile(self):
        # Path where the sbatch script is written: a throwaway temp file when
        # no scripts_dir is configured, otherwise '<scripts_dir>/<name>.sh'
        # (creating scripts_dir and log_dir if needed).
        if self.scripts_dir is None:
            return tmp()
        else:
            for _dir in [self.scripts_dir, self.log_dir]:
                if not os.path.exists(_dir):
                    os.makedirs(_dir)
            return f"{self.scripts_dir}/{self.name}.sh"

    def run(self,
            command: str,
            name_addition: Optional[str] = None,
            cmd_kwargs: Optional[dict[str, str]] = None,
            _cmd: str = "sbatch",
            tries: int = 1,
            depends_on: Optional[Sequence[int]] = None,
            dependency_type: str = "afterok"
            ) -> Optional[int]:
        """
        command: a bash command that you want to run
        name_addition: if not specified, the sha1 of the command to run
                       appended to job name. if it is "date", the yyyy-mm-dd
                       date will be added to the job name.
        cmd_kwargs: dict of extra arguments to fill in command
                    (so command itself can be a template).
        _cmd: submit command (change to "bash" for testing).
        tries: try to run a job either this many times or until the first
               success.
        depends_on: job ids that this depends on before it is run
        dependency_type: after, afterok, afterany, afternotok

        Returns the job id of the first submission, or None if sbatch did
        not report a successful submission.
        """
        if name_addition is None:
            name_addition = hashlib.sha1(command.encode("utf-8")).hexdigest()

        if self.date_in_name:
            name_addition += "-" + str(datetime.date.today())
        name_addition = name_addition.strip(" -")

        if cmd_kwargs is None:
            cmd_kwargs = {}

        # temporarily extend the job name; restored after submission below
        n = self.name
        self.name = self.name.strip(" -")
        self.name += ("-" + name_addition.strip(" -"))
        # cmd_kwargs become 'export KEY=VALUE' lines prepended to the command
        args = []
        for k, v in cmd_kwargs.items():
            args.append(f"export {k}={v}")
        args = "\n".join(args)

        tmpl = str(self).replace("__script__", args + "\n###\n" + command)
        # normalize a [None] / None dependency list to "no dependencies"
        if depends_on is None or (len(depends_on) == 1 and depends_on[0] is None):
            depends_on = []

        with open(self._tmpfile(), "w", encoding="utf8") as sh:
            sh.write(tmpl)

        job_id = None
        for itry in range(1, tries + 1):
            args = [_cmd]
            if depends_on is not None and len(depends_on) > 0:
                dep = f"--dependency={dependency_type}:" + ":".join([str(x) for x in depends_on])
                args.append(dep)
            if itry > 1:
                # retries only run if the previous attempt failed
                mid = f"--dependency=afternotok:{job_id}"
                args.append(mid)
            args.append(sh.name)
            res = subprocess.check_output(args).strip()
            print(res.decode(), file=sys.stderr)
            # restore the original job name for subsequent run() calls
            self.name = n
            if not res.startswith(b"Submitted batch"):
                return None
            j_id = int(res.split()[-1])
            if itry == 1:
                job_id = j_id
        return job_id
|
||||
|
||||
|
||||
# Run the module's doctests when executed directly.
if __name__ == "__main__":
    import doctest
    doctest.testmod()
|
||||
228
environments/optimizer/FOB/pytorch_fob/engine/utils.py
Normal file
228
environments/optimizer/FOB/pytorch_fob/engine/utils.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, Optional, Type
|
||||
import json
|
||||
import math
|
||||
import signal
|
||||
import torch
|
||||
from lightning_utilities.core.rank_zero import rank_zero_only, rank_zero_info, rank_zero_debug, log
|
||||
|
||||
|
||||
def set_loglevel(level: str):
    """Set the verbosity of the 'lightning.pytorch' logger.

    Accepted levels: debug, info, warn, error, silent.
    Unknown level names leave the logger unchanged.
    """
    level_map = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warn": logging.WARNING,
        "error": logging.ERROR,
        "silent": logging.CRITICAL,
    }
    if level in level_map:
        logging.getLogger("lightning.pytorch").setLevel(level_map[level])
|
||||
|
||||
|
||||
@rank_zero_only
def rank_zero_print(*args: Any, **kwargs: Any):
    """print() that only executes on the rank-zero process."""
    print(*args, **kwargs)
|
||||
|
||||
|
||||
@rank_zero_only
def log_warn(msg: str, *args: Any, prefix: str = "[FOB WARNING] ", **kwargs: Any):
    """Log a warning with the FOB prefix, only on the rank-zero process."""
    full_msg = f"{prefix}{msg}"
    return log.warning(full_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def log_info(msg: str, *args: Any, prefix: str = "[FOB INFO] ", **kwargs: Any):
    """Log an info message with the FOB prefix (rank-zero only via rank_zero_info)."""
    full_msg = f"{prefix}{msg}"
    return rank_zero_info(full_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def log_debug(msg: str, *args: Any, prefix: str = "[FOB DEBUG] ", **kwargs: Any):
    """Log a debug message with the FOB prefix (rank-zero only via rank_zero_debug)."""
    full_msg = f"{prefix}{msg}"
    return rank_zero_debug(full_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def write_results(results, filepath: Path):
    """Serialize *results* as pretty-printed JSON (indent=4) to *filepath*."""
    serialized = json.dumps(results, indent=4)
    with open(filepath, "w", encoding="utf8") as f:
        f.write(serialized)
    print(f"Saved results into {filepath}.")
|
||||
|
||||
|
||||
def wrap_list(x: Any) -> list[Any]:
    """Return *x* unchanged if it is a list, otherwise a single-element list."""
    return x if isinstance(x, list) else [x]
|
||||
|
||||
|
||||
def calculate_steps(epochs: int, datapoints: int, devices: int, batch_size: int) -> int:
    """Total number of steps: ceil(datapoints / batch_size / devices), times epochs."""
    steps_per_epoch = math.ceil(datapoints / batch_size / devices)
    return steps_per_epoch * epochs
|
||||
|
||||
|
||||
def some(*args, default):
    """
    returns the first argument that is not None or default.
    """
    for candidate in args:
        if candidate is not None:
            return candidate
    return default
|
||||
|
||||
|
||||
def maybe_abspath(path: Optional[str | Path]) -> Optional[Path]:
|
||||
if path is None:
|
||||
return None
|
||||
return Path(path).resolve()
|
||||
|
||||
|
||||
def findfirst(f: Callable, xs: Iterable):
    """Return the first element of *xs* for which predicate *f* is truthy, else None."""
    return next((x for x in xs if f(x)), None)
|
||||
|
||||
|
||||
def trainer_strategy(devices: int | list[int] | str) -> str:
|
||||
if isinstance(devices, str):
|
||||
return "auto"
|
||||
ndevices = devices if isinstance(devices, int) else len(devices)
|
||||
return "ddp" if ndevices > 1 else "auto"
|
||||
|
||||
|
||||
def gpu_suited_for_compile() -> bool:
    """Whether the current CUDA device has a compute capability known to suit torch.compile.

    Checks for compute capability 7.0, 8.0 or 9.0. Returns False when CUDA
    is unavailable (previously fell through and returned None implicitly).
    """
    if not torch.cuda.is_available():
        return False
    device_cap = torch.cuda.get_device_capability()
    return device_cap in ((7, 0), (8, 0), (9, 0))
|
||||
|
||||
|
||||
def precision_with_fallback(precision: str) -> str:
    """
    Check if cuda supports bf16, if not using cuda or if not available return 16 instead of bf16.

    Only bf16 variants are rewritten (e.g. "bf16-mixed" -> "16-mixed"); other
    precisions are returned unchanged. Previously the no-CUDA branch stripped
    the first two characters of *any* precision string ("16-mixed" -> "-mixed").
    """
    if not torch.cuda.is_available():
        log_warn("Warning: No CUDA available. Results can be different!")
        return precision[2:] if precision.startswith("bf") else precision
    if precision.startswith("bf") and not torch.cuda.is_bf16_supported():
        log_warn("Warning: GPU does not support bfloat16. Results can be different!")
        return precision[2:]
    return precision
|
||||
|
||||
|
||||
def str_to_seconds(s: str) -> int:
    """Parse a 'HH:MM:SS' string into the total number of seconds."""
    parts = s.split(":")
    assert len(parts) == 3, f"Invalid time format: {s}. Use 'HH:MM:SS'."
    hours, minutes, seconds = (int(p) for p in parts)
    return 3600 * hours + 60 * minutes + seconds
|
||||
|
||||
|
||||
def seconds_to_str(total_seconds: int, sep: str = ":") -> str:
    """Format a number of seconds as zero-padded 'HH:MM:SS' with configurable separator."""
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}{sep}{minutes:02d}{sep}{seconds:02d}"
|
||||
|
||||
|
||||
def begin_timeout(delay=10, show_threads=False):
    """Schedule a SIGALRM after *delay* seconds.

    With *show_threads*, first print a stack trace for every live thread.
    """
    if show_threads:
        import sys
        import threading
        import traceback
        names = {t.ident: t.name for t in threading.enumerate()}
        for ident, frame in sys._current_frames().items():
            print(f"Thread {names.get(ident, ident)}:")
            traceback.print_stack(frame)
            print()
    signal.alarm(delay)  # raise SIGALRM after `delay` seconds
|
||||
|
||||
|
||||
def path_to_str_inside_dict(d: dict) -> dict:
    """Return a copy of *d* with every Path value converted to str (recursively)."""
    return convert_type_inside_dict(d, src=Path, tgt=str)
|
||||
|
||||
|
||||
def convert_type_inside_dict(d: dict, src: Type, tgt: Type) -> dict:
    """Return a recursive copy of *d* where every value of type *src* is converted to *tgt*.

    Nested dicts are converted first, then the (possibly converted) value is
    checked against *src*.
    """
    converted = {}
    for key, value in d.items():
        if isinstance(value, dict):
            value = convert_type_inside_dict(value, src, tgt)
        converted[key] = tgt(value) if isinstance(value, src) else value
    return converted
|
||||
|
||||
|
||||
def dict_differences(custom: dict[str, Any], default: dict[str, Any]) -> dict[str, Any]:
    """
    Recursively returns a dictionary with the items in `custom` that are different or missing from `default`.

    Example:
    >>> dict_differences({"hi": 3, "bla": {"a": 2, "b": 2}}, {"hi": 2, "bla": {"a": 1, "b": 2}})
    {'hi': 3, 'bla': {'a': 2}}
    """
    diff: dict[str, Any] = {}
    for key, value in custom.items():
        if key not in default:
            diff[key] = value
            continue
        reference = default[key]
        if reference == value:
            continue
        both_dicts = isinstance(value, dict) and isinstance(reference, dict)
        diff[key] = dict_differences(value, reference) if both_dicts else value
    return diff
|
||||
|
||||
|
||||
def concatenate_dict_keys(
        d: dict[str, Any],
        parent_key: str = "",
        sep: str = ".",
        exclude_keys: Iterable[str] = tuple()
) -> dict[str, Any]:
    """
    Example:
    >>> concatenate_dict_keys({ "A": { "B": { "C": 1, "D": 2 }, "E": { "F": 3 } } })
    {'A.B.C': 1, 'A.B.D': 2, 'A.E.F': 3}
    >>> concatenate_dict_keys({ "A": { "B": { "C": 1, "D": 2 }, "E": { "F": 3 } } }, exclude_keys=["B"])
    {'A.E.F': 3}
    """
    flattened: dict[str, Any] = {}
    for key, value in d.items():
        if key in exclude_keys:
            continue
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flattened.update(concatenate_dict_keys(value, full_key, sep, exclude_keys))
        else:
            flattened[full_key] = value
    return flattened
|
||||
|
||||
|
||||
def sort_dict_recursively(d: dict) -> dict:
    """Return a copy of *d* with keys sorted at every nesting level."""
    return {
        key: sort_dict_recursively(value) if isinstance(value, dict) else value
        for key, value in sorted(d.items())
    }
|
||||
|
||||
|
||||
class EndlessList(list):
    """
    Returns first element if out of bounds. Otherwise same as list.
    """
    def __getitem__(self, index):
        # wrap reads past the end of a non-empty list back to element 0
        if index >= len(self) and len(self) > 0:
            index = 0
        return super().__getitem__(index)
|
||||
|
||||
|
||||
class AttributeDict(dict):
    """dict whose entries can also be read as attributes: ``d.key`` == ``d["key"]``.

    Real attributes (dict methods, dunder names) take precedence; only when
    normal attribute lookup fails is the name looked up as a dict key,
    raising KeyError if absent.
    """

    def __getattribute__(self, key: str) -> Any:
        # try normal attribute resolution first so dict methods still work
        try:
            return super().__getattribute__(key)
        except AttributeError:
            pass
        # fall back to item lookup (raises KeyError for unknown names)
        return super().__getitem__(key)
|
||||
Loading…
Add table
Add a link
Reference in a new issue