atropos/environments/optimizer/FOB/pytorch_fob/engine/run.py
2025-05-18 16:36:28 -07:00

298 lines
14 KiB
Python

import hashlib
from pathlib import Path
import time
from typing import Any, Optional
from lightning import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import Logger, TensorBoardLogger, CSVLogger
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
import torch
import yaml
from pytorch_fob.engine.callbacks import LogTrainingStats, OptimizerTime, PrintEpochWithTime, RestrictTrainEpochs
from pytorch_fob.engine.configs import EngineConfig, EvalConfig, OptimizerConfig, TaskConfig
from pytorch_fob.engine.utils import AttributeDict, EndlessList, calculate_steps, concatenate_dict_keys, convert_type_inside_dict, dict_differences, findfirst, path_to_str_inside_dict, precision_with_fallback, seconds_to_str, trainer_strategy, write_results, log_warn, log_info
from pytorch_fob.optimizers.optimizers import Optimizer
from pytorch_fob.tasks.tasks import TaskDataModule, TaskModel, import_task
class Run():
    """A single benchmark run: holds the resolved configs, derives the output
    paths, and drives the train/validate/test lifecycle via `start()`."""

    def __init__(
            self,
            config: dict[str, Any],
            default_config: dict[str, Any],
            task_key: str,
            optimizer_key: str,
            engine_key: str,
            eval_key: str,
            identifier_key: str
    ) -> None:
        """
        setup: download and prepare data before creating the Run
        """
        # raw resolved config and the defaults it is diffed against in _set_outpath
        self._config = config
        self._default_config = default_config
        # keys of the sub-sections inside `config`
        self.task_key = task_key
        self.optimizer_key = optimizer_key
        self.engine_key = engine_key
        self.eval_key = eval_key
        self.identifier_key = identifier_key
        self._generate_configs()  # builds self.engine / self.optimizer / self.task / self.evaluation
        self._set_outpath()       # derives self.run_dir and self.checkpoint_dir
        self._callbacks = AttributeDict({})  # filled lazily by _init_callbacks()
    def start(self) -> dict[str, _EVALUATE_OUTPUT]:
        """Run the configured pipeline (train and/or test) and return all scores.

        Results are also written to <run_dir>/scores.json.
        """
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.export_config()
        scores: dict[str, _EVALUATE_OUTPUT] = {}
        if any([self.engine.train, self.engine.test]):
            # resolve resume checkpoint and max_steps before seeding / model creation
            self._ensure_resume_path()
            self.ensure_max_steps()
            torch.set_float32_matmul_precision('high')
            seed_everything(self.engine.seed, workers=True)
            model, data_module = self.get_task()
            if self.engine.train:
                trainer = self.get_trainer()
                self._train(trainer, model, data_module)
                scores["mean_optimizer_time_ms"] = self._callbacks["optimizer_time"].total_mean_optimizer_step_time_ms
                if self.engine.validate:
                    scores["validation"] = self._validate(trainer, model, data_module)
            if self.engine.test:
                tester = self.get_tester()
                # decide which checkpoint to evaluate for "test_final"
                if self.engine.train:  # no need to load last checkpoint, model is already loaded
                    ckpt = None
                elif self.engine.resume is not None:
                    ckpt = self.engine.resume
                else:
                    log_warn(
                        "No last checkpoint found, evaluating untrained model. " + \
                        "If this is unexpected, try to set 'engine.resume=true'."
                    )
                    ckpt = None
                scores["test_final"] = self._test(tester, model, data_module, ckpt=ckpt)  # type: ignore (see ensure_resume_path)
                best_path = self.get_best_checkpoint()
                if best_path is not None:
                    scores["test_best"] = self._test(tester, model, data_module, Path(best_path))
                else:
                    log_info("No best checkpoint found, skipping test.")
        write_results(scores, self.run_dir / "scores.json")
        return scores
def _train(self, trainer: Trainer, model: LightningModule, data_module: LightningDataModule):
start_time = time.time()
if self.engine.accelerator == "gpu" and torch.cuda.is_available():
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=(self.engine.optimize_memory or not self.engine.deterministic)
):
trainer.fit(model, datamodule=data_module, ckpt_path=self.engine.resume) # type: ignore
else:
trainer.fit(model, datamodule=data_module, ckpt_path=self.engine.resume) # type: ignore
end_time = time.time()
train_time = int(end_time - start_time)
log_info(f"Finished training in {seconds_to_str(train_time)}.")
# Write train_time.txt
train_time_path = self.run_dir / "train_time.txt"
with open(train_time_path, "w") as f:
f.write(str(train_time) + "\n")
def _validate(self, trainer: Trainer, model: LightningModule, data_module: LightningDataModule) -> _EVALUATE_OUTPUT:
score = trainer.validate(model, datamodule=data_module)
return score
def _test(self, tester: Trainer, model: LightningModule, data_module: LightningDataModule, ckpt: Optional[Path] = None) -> _EVALUATE_OUTPUT:
ckpt_path = self.engine.resume if ckpt is None else ckpt
mode = "final" if ckpt_path is None or ckpt_path.stem.startswith("last") else "best" # type: ignore
log_info(f"Testing {mode} checkpoint...")
score = tester.test(model, datamodule=data_module, ckpt_path=ckpt_path) # type: ignore
write_results(score, self.run_dir / f"results_{mode}_model.json")
return score
def export_config(self):
with open(self.run_dir / "config.yaml", "w", encoding="utf8") as f:
d = path_to_str_inside_dict(self._config)
d = convert_type_inside_dict(d, EndlessList, list)
yaml.safe_dump(d, f)
def export_config_dict(self) -> dict[str, Any]:
d = path_to_str_inside_dict(self._config)
d = convert_type_inside_dict(d, EndlessList, list)
return d
def get_config(self) -> AttributeDict:
return AttributeDict(self._config)
def get_optimizer(self) -> Optimizer:
return Optimizer(self.optimizer)
def get_task(self) -> tuple[TaskModel, TaskDataModule]:
task_module = import_task(self.task.name)
return task_module.get_task(self.get_optimizer(), self.task)
def get_datamodule(self) -> TaskDataModule:
task_module = import_task(self.task.name)
return task_module.get_datamodule(self.task)
def get_callbacks(self) -> list[Callback]:
if len(self._callbacks) < 1:
self._init_callbacks()
return list(self._callbacks.values())
def get_loggers(self) -> list[Logger]:
return [
TensorBoardLogger(
save_dir=self.run_dir,
name="tb_logs"
),
CSVLogger(
save_dir=self.run_dir,
name="csv_logs"
)
]
    def get_trainer(self) -> Trainer:
        """Build the training Trainer from the engine configuration."""
        return Trainer(
            max_steps=self.engine.max_steps,
            logger=self.get_loggers(),
            callbacks=self.get_callbacks(),
            devices=self.engine.devices,
            strategy=trainer_strategy(self.engine.devices),
            enable_progress_bar=(not self.engine.silent),
            deterministic=self.engine.deterministic,
            detect_anomaly=self.engine.detect_anomaly,
            gradient_clip_val=self.engine.gradient_clip_val,
            gradient_clip_algorithm=self.engine.gradient_clip_alg,
            precision=precision_with_fallback(self.engine.precision),  # type: ignore
            accelerator=self.engine.accelerator,
            # NOTE(review): "inteval" is a typo carried in the engine config schema;
            # renaming it would be a config-breaking change.
            log_every_n_steps=self.engine.logging_inteval
        )
    def get_tester(self) -> Trainer:
        """Build a lightweight Trainer used only for testing (no loggers attached)."""
        return Trainer(
            devices=1,  # testing always runs on a single device — presumably to avoid multi-rank metric aggregation; confirm
            logger=False,
            enable_progress_bar=(not self.engine.silent),
            deterministic=self.engine.deterministic,
            precision=precision_with_fallback(self.engine.precision),  # type: ignore
            accelerator=self.engine.accelerator
        )
def get_best_checkpoint(self) -> Optional[Path]:
model_checkpoint = self._callbacks.get("best_model_checkpoint", None)
if model_checkpoint is not None:
model_checkpoint = Path(model_checkpoint.best_model_path)
model_checkpoint = model_checkpoint if not model_checkpoint.is_dir() else None
if model_checkpoint is None:
available_checkpoints = self.get_available_checkpoints()
model_checkpoint = findfirst(lambda x: x.stem.startswith("best"), available_checkpoints)
return model_checkpoint
def get_available_checkpoints(self) -> list[Path]:
if self.checkpoint_dir.exists():
return list(filter(lambda x: x.suffix == ".ckpt", self.checkpoint_dir.iterdir()))
return []
    def ensure_max_steps(self):
        """
        Ensures that `self.task.max_steps` is calculated and set correctly.
        """
        if self.task.max_steps is None:
            max_steps = self._calc_max_steps()
            # write the computed value back into the raw config (and into the
            # defaults when unset there too, keeping both dicts in sync)
            self._config[self.task_key]["max_steps"] = max_steps
            if self._default_config[self.task_key]["max_steps"] is None:
                self._default_config[self.task_key]["max_steps"] = max_steps
            self._generate_configs()  # rebuild config views so self.task reflects the new value
            log_info(f"'max_steps' not set explicitly, using {max_steps=} (calculated from " +
                     f"max_epochs={self.task.max_epochs}, batch_size={self.task.batch_size}, devices={self.engine.devices})")
def _ensure_resume_path(self):
"""
Ensures that `self.engine.resume` is either a valid Path or None.
"""
if isinstance(self.engine.resume, Path):
pass
elif isinstance(self.engine.resume, bool):
resume_path = None
if self.engine.resume:
available_checkpoints = self.get_available_checkpoints()
if len(available_checkpoints) < 1:
log_warn("engine.resume=True but no checkpoint was found. Starting run from scratch.")
else:
resume_path = findfirst(lambda x: x.stem == "last", available_checkpoints)
self._config[self.engine_key]["resume"] = resume_path
self._generate_configs()
else:
raise TypeError(f"Unsupportet type for 'resume', got {type(self.engine.resume)=}.")
def _calc_max_steps(self) -> int:
dm = self.get_datamodule()
dm.setup("fit")
train_samples = len(dm.data_train)
return calculate_steps(self.task.max_epochs, train_samples, self.engine.devices, self.task.batch_size)
    def _init_callbacks(self):
        """Populate `self._callbacks`; get_callbacks() returns these in insertion order."""
        self._callbacks["optimizer_time"] = OptimizerTime()
        # tracks the best checkpoint according to the task's target metric
        self._callbacks["best_model_checkpoint"] = ModelCheckpoint(
            dirpath=self.checkpoint_dir,
            filename="best-{epoch}-{step}",
            monitor=self.task.target_metric,
            mode=self.task.target_metric_mode
        )
        # per-epoch checkpoint; save_last=True provides the "last" file used for resuming
        self._callbacks["model_checkpoint"] = ModelCheckpoint(
            dirpath=self.checkpoint_dir,
            enable_version_counter=False,
            every_n_epochs=1,
            save_last=True
        )
        if self.engine.early_stopping is not None:
            # engine.early_stopping holds the patience value
            self._callbacks["early_stopping"] = EarlyStopping(
                monitor=self.engine.early_stopping_metric,
                mode=self.task.target_metric_mode,
                patience=self.engine.early_stopping,
                check_finite=self.engine.check_finite,
                log_rank_zero_only=True
            )
        self._callbacks["lr_monitor"] = LearningRateMonitor(
            logging_interval=self.optimizer.lr_interval
        )
        if self.engine.log_extra:
            # log_extra may be truthy-bool (defaults) or a dict of LogTrainingStats kwargs
            self._callbacks["extra"] = LogTrainingStats(
                log_every_n_steps=self.engine.logging_inteval,
                **(self.engine.log_extra if isinstance(self.engine.log_extra, dict) else {})
            )
        self._callbacks["print_epoch"] = PrintEpochWithTime(self.engine.silent)
        if self.engine.restrict_train_epochs is not None:
            self._callbacks["restrict_train_epochs"] = RestrictTrainEpochs(self.engine.restrict_train_epochs)
        # TODO: callback for logging time per step
def outpath_exclude_keys(self) -> list[str]:
return [
self.eval_key,
"output_dir_name"
]
    def _set_outpath(self):
        """Derive `self.run_dir` and `self.checkpoint_dir` from the config's diff vs. defaults."""
        base: Path = self.engine.output_dir / self.task.output_dir_name / self.optimizer.output_dir_name
        exclude_keys = self.outpath_exclude_keys()
        exclude_keys += self.engine.outpath_irrelevant_engine_keys()
        # the directory name encodes every setting that deviates from the defaults
        diffs = concatenate_dict_keys(dict_differences(self._config, self._default_config), exclude_keys=exclude_keys)
        run_dir = ",".join(f"{k}={str(v)}" for k, v in sorted(diffs.items())) if diffs else "default"
        if len(run_dir) > 254:  # max file name length
            # fall back to a stable md5 hash so the path stays within filesystem limits
            hashdir = hashlib.md5(run_dir.encode()).hexdigest()
            log_info(f"folder name {run_dir} is too long, using {hashdir} instead.")
            run_dir = hashdir
        self.run_dir = base / run_dir
        self.checkpoint_dir = self.run_dir / "checkpoints"
    def _generate_configs(self):
        """(Re)build the typed config views from the raw `_config` dict.

        Called again whenever `_config` is mutated (see `ensure_max_steps` and
        `_ensure_resume_path`) so the views stay in sync with the raw dict.
        """
        self.engine = EngineConfig(self._config, self.task_key, self.engine_key)
        self.optimizer = OptimizerConfig(self._config, self.optimizer_key, self.task_key, self.identifier_key)
        self.task = TaskConfig(self._config, self.task_key, self.engine_key, self.identifier_key)
        self.evaluation = EvalConfig(
            self._config,
            eval_key=self.eval_key,
            engine_key=self.engine_key,
            # keys that must not be treated as evaluation-relevant differences
            ignore_keys=self.engine.outpath_irrelevant_engine_keys(prefix=f"{self.engine_key}.") + [f"{self.optimizer_key}.output_dir_name", f"{self.task_key}.output_dir_name"]
        )