atropos/environments/optimizer/FOB/pytorch_fob/engine/callbacks.py
2025-05-18 16:36:28 -07:00

272 lines
12 KiB
Python

import math
import time
from typing import Iterable, Optional
import deepspeed
import torch
from lightning import Callback, LightningModule, Trainer
from lightning_utilities.core.rank_zero import rank_zero_only
from torch.linalg import vector_norm
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, seconds_to_str
class RestrictTrainEpochs(Callback):
    """Stop training once ``max_epochs`` training epochs have been counted.

    The epoch counter is reset to zero whenever training starts. After the
    callback has been restored from a checkpoint, the next training-epoch end is
    not counted (see ``on_load_checkpoint``).
    """

    def __init__(self, max_epochs: int):
        super().__init__()
        self.max_epochs = max_epochs
        self.epochs = 0
        # Set by on_load_checkpoint: the next epoch end is not counted.
        self.skip_first = False

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        log_debug(f"Training for {self.max_epochs} epochs...")
        self.epochs = 0
        trainer.should_stop = False

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if not self.skip_first:
            self.epochs += 1
        else:
            self.skip_first = False
        log_debug(f"Epoch {self.epochs}/{self.max_epochs}")
        # TODO: test for DDP, do we need 'trainer.strategy.reduce_boolean_decision'?
        limit_reached = self.epochs >= self.max_epochs
        if limit_reached:
            log_debug(f"Stopping training after {self.epochs} epochs")
            trainer.should_stop = True

    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint):
        # The checkpoint restores the model at the end of an epoch, so the first
        # epoch end seen after resuming must not be counted again.
        self.skip_first = True
class OptimizerTime(Callback):
    """Log the mean optimizer step time per epoch and track a running mean over epochs.

    Step durations (milliseconds) are read from ``pl_module.optimizer_times_ms``,
    which the module is expected to fill; the list is replaced by an empty one after
    each epoch that had at least one entry. The running mean is exposed through
    ``state_dict``/``load_state_dict`` so it can be checkpointed.
    """

    def __init__(self):
        super().__init__()
        # Mean over all per-epoch means recorded so far.
        self.total_mean_optimizer_step_time_ms: float = 0.0
        # Number of epochs that contributed to the running mean.
        self.total_epochs: int = 0

    def on_train_epoch_end(self, trainer, pl_module):
        step_times = pl_module.optimizer_times_ms
        if len(step_times) == 0:
            return
        epoch_mean = sum(step_times) / len(step_times)
        pl_module.log("mean_optimizer_step_time_ms", epoch_mean, on_step=False, on_epoch=True, sync_dist=True)
        # Fold this epoch's mean into the running mean.
        previous_epochs = self.total_epochs
        self.total_epochs = previous_epochs + 1
        self.total_mean_optimizer_step_time_ms = (
            self.total_mean_optimizer_step_time_ms * previous_epochs + epoch_mean
        ) / self.total_epochs
        # Start collecting step times for the next epoch from scratch.
        pl_module.optimizer_times_ms = []  # type: ignore

    def state_dict(self) -> dict[str, float | int]:
        return dict(running_mean=self.total_mean_optimizer_step_time_ms, total_epochs=self.total_epochs)

    def load_state_dict(self, state_dict: dict[str, float | int]):
        self.total_mean_optimizer_step_time_ms = state_dict["running_mean"]
        self.total_epochs = state_dict["total_epochs"]  # type: ignore
class PrintEpochWithTime(Callback):
    """Log the wall-clock time spent in training and validation for each epoch.

    All hooks run on rank zero only. A summary line is logged at the end of a
    training epoch only when that epoch's training start and a validation
    start/end have all been recorded, so epochs without a validation run are not
    reported. Validation hooks fired during Lightning's sanity check are ignored,
    so sanity-check timing is never attributed to a training epoch.
    """

    def __init__(self, active: bool = True):
        super().__init__()
        self.active: bool = active
        self.time: dict[str, Optional[float]]
        self.reset_time()

    def reset_time(self):
        """Forget all recorded timestamps."""
        self.time = {"train_start": None, "val_start": None, "val_end": None}

    @rank_zero_only
    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        if self.active:
            self.time["train_start"] = time.time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        # Printed here because the training epoch ends after validation is done,
        # i.e. the measured training interval contains the validation interval.
        if not self.active or any(v is None for v in self.time.values()):
            return
        max_epochs = pl_module.config.max_epochs
        train_time = math.ceil(time.time() - self.time["train_start"])  # type: ignore
        val_time = math.ceil(self.time["val_end"] - self.time["val_start"])  # type: ignore
        log_info(
            f"Finished training epoch {trainer.current_epoch + 1} of {max_epochs}. Time spent: training: {seconds_to_str(train_time - val_time)}, validation: {seconds_to_str(val_time)}, total: {seconds_to_str(train_time)}."
        )
        self.reset_time()

    @rank_zero_only
    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        # The sanity-check validation runs before the first training epoch; recording
        # it would let an epoch without validation report the sanity-check duration.
        if self.active and not trainer.sanity_checking:
            self.time["val_start"] = time.time()

    @rank_zero_only
    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if self.active and not trainer.sanity_checking:
            self.time["val_end"] = time.time()
def metric_fn(metric: str, v: torch.Tensor, override: Optional[float] = None) -> float:
    """Reduce tensor ``v`` to the scalar statistic named ``metric``.

    Supported names: mean, sum, abs_mean, std, abs_std, min, max, l1, l2,
    sq_mean, sq_sum.

    Args:
        metric: name of the statistic to compute.
        v: tensor to reduce.
        override: if not None, returned as-is; ``v`` is not touched and the
            metric name is not validated.

    Returns:
        The statistic as a Python number.

    Raises:
        ValueError: if ``metric`` is unknown and no override was given.
    """
    if override is not None:
        return override
    reducers = {
        "mean": lambda t: t.mean(),
        "sum": lambda t: t.sum(),
        "abs_mean": lambda t: t.abs().mean(),
        "std": lambda t: t.std(),
        "abs_std": lambda t: t.abs().std(),
        "min": lambda t: t.min(),
        "max": lambda t: t.max(),
        "l1": lambda t: vector_norm(t, ord=1),
        "l2": lambda t: vector_norm(t, ord=2),
        "sq_mean": lambda t: (t**2).mean(),
        "sq_sum": lambda t: (t**2).sum(),
    }
    reducer = reducers.get(metric)
    if reducer is None:
        raise ValueError(f"unknown metric {metric}")
    return reducer(v).item()
def add_metrics_to_stats(
    stats: dict[str, float],
    prefix: str,
    name: str,
    v: torch.Tensor,
    metrics: Iterable[str],
    override: Optional[float] = None,
):
    """Compute every statistic in ``metrics`` for ``v`` and write it into ``stats``.

    ``stats`` is modified in place; each entry is stored under the key
    ``"{prefix}/{name}/{metric}"``. If ``override`` is not None, that value is
    stored for every metric instead of a computed statistic.
    """
    for metric_name in metrics:
        value = metric_fn(metric_name, v, override=override)
        stats[f"{prefix}/{name}/{metric_name}"] = value
class LogTrainingStats(Callback):
    """Periodically log statistics of parameters, gradients and optimizer state.

    Statistics are collected in ``on_before_optimizer_step`` on rank zero every
    ``log_every_n_steps`` optimizer steps and passed to every logger attached to the
    trainer. If ``change_log_interval_every_n_steps`` is set, the interval is
    multiplied by ``log_interval_factor`` at every multiple of that step count and
    clamped to ``[min_log_interval, max_log_interval]``.

    Logged keys (``name`` comes from ``param_group["names"]``):
        param/{name}/{metric}, param/{name}/quantile-{q}       -- log_params
        grad/{name}/mean, grad/{name}/{metric},
        grad/{name}/quantile-{q}                                -- log_gradient
        param/{name}/effective_lr (||grad||_2 / ||param||_2),
        param/{name}/lr                                         -- log_lrs
        1st_order_momentum/{name}/{metric},
        2nd_order_momentum/{name}/{metric}                      -- log_momentum
    Quantiles are only logged when ``log_quantiles`` is set. Gradients that contain
    NaN/Inf have all their gradient statistics logged as ``-10.0``.
    """

    # Quantile levels logged when ``log_quantiles`` is enabled.
    _QUANTILES = (0.25, 0.5, 0.75)
    # Quantiles are skipped for tensors with at least this many elements
    # (presumably to stay within torch.quantile's input-size limit).
    _MAX_QUANTILE_NUMEL = 10_000_000
    # Logged in place of every statistic of a gradient that contains NaN/Inf.
    _NON_FINITE_SENTINEL = -10.0

    def __init__(
        self,
        log_gradient: bool = True,
        log_params: bool = True,
        log_quantiles: bool = False,
        log_momentum: bool = False,
        log_lrs: bool = True,
        log_every_n_steps: int = 50,
        change_log_interval_every_n_steps: Optional[int] = None,
        log_interval_factor: float = 2.0,
        min_log_interval: int = 1,
        max_log_interval: Optional[int] = None,
        metrics: Iterable[str] = ("mean", "abs_mean", "std", "abs_std", "min", "max", "l1", "l2", "sq_mean"),
    ):
        super().__init__()
        self.log_gradient = log_gradient
        self.log_params = log_params
        self.log_quantiles = log_quantiles
        self.log_momentum = log_momentum
        self.log_lrs = log_lrs
        self.log_every_n_steps = log_every_n_steps
        self.change_log_interval_every_n_steps = change_log_interval_every_n_steps
        self.log_interval_factor = log_interval_factor
        self.min_log_interval = min_log_interval
        self.max_log_interval = max_log_interval
        self.metrics = metrics

    def _check_and_adjust_log_interval(self, trainer: Trainer, pl_module: LightningModule) -> bool:
        """Apply a scheduled interval change (if any) and return whether to log at this step."""
        if self.change_log_interval_every_n_steps is not None:
            if trainer.global_step > 0 and trainer.global_step % self.change_log_interval_every_n_steps == 0:
                self.log_every_n_steps = math.ceil(self.log_every_n_steps * self.log_interval_factor)
                self.log_every_n_steps = max(self.log_every_n_steps, self.min_log_interval)
                if self.max_log_interval is not None:
                    self.log_every_n_steps = min(self.log_every_n_steps, self.max_log_interval)
        pl_module.log("logging_interval", self.log_every_n_steps)
        return trainer.global_step % self.log_every_n_steps == 0

    @rank_zero_only
    def on_before_optimizer_step(self, trainer: Trainer, pl_module: LightningModule, optimizer: torch.optim.Optimizer):
        """Collect statistics for every named parameter and send them to all loggers."""
        if not self._check_and_adjust_log_interval(trainer, pl_module):
            return
        stats: dict[str, float] = {}
        for param_group in optimizer.param_groups:
            for name, param in zip(param_group["names"], param_group["params"]):
                self._collect_param_stats(stats, trainer, optimizer, name, param)
        if trainer.loggers is not None:
            for logger in trainer.loggers:
                logger.log_metrics(stats, step=trainer.global_step)

    def _collect_param_stats(
        self,
        stats: dict[str, float],
        trainer: Trainer,
        optimizer: torch.optim.Optimizer,
        name: str,
        param: torch.Tensor,
    ) -> None:
        """Add all enabled statistics for a single parameter to ``stats``."""
        param_data = param.detach()
        if self.log_params:
            self._warn_non_finite("param", name, param_data)
            add_metrics_to_stats(stats, "param", name, param_data, self.metrics)
            self._add_quantiles(stats, "param", name, param_data)

        grad_data = self._get_grad(trainer, param)
        if grad_data is not None:
            grad_is_finite = not self._warn_non_finite("grad", name, grad_data)
            if self.log_gradient:
                self._add_grad_stats(stats, name, grad_data, grad_is_finite)
            if self.log_lrs:
                grad_norm = vector_norm(grad_data)
                param_norm = vector_norm(param_data)
                effective_lr = (grad_norm / param_norm).item() if param_norm != 0 else 0.0
                stats[f"param/{name}/effective_lr"] = effective_lr

        if self.log_momentum or self.log_lrs:
            self._add_optimizer_state_stats(stats, optimizer, name, param)

    @staticmethod
    def _warn_non_finite(kind: str, name: str, tensor: torch.Tensor) -> bool:
        """Warn about NaN/Inf entries in ``tensor``; return True if any were found."""
        has_nan = bool(torch.isnan(tensor).any())
        has_inf = bool(torch.isinf(tensor).any())
        if has_nan:
            log_warn(f"# NaN in {kind} {name}")
        if has_inf:
            log_warn(f"# Inf in {kind} {name}")
        return has_nan or has_inf

    def _get_grad(self, trainer: Trainer, param: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the gradient of ``param`` if gradient or LR logging needs it, else None."""
        if not (self.log_gradient or self.log_lrs) or not param.requires_grad:
            return None
        if trainer.num_devices > 1:
            # Assumes a DeepSpeed strategy when training on more than one device.
            return deepspeed.utils.safe_get_full_grad(param)
        return param.grad

    def _add_grad_stats(self, stats: dict[str, float], name: str, grad_data: torch.Tensor, is_finite: bool) -> None:
        """Add gradient statistics, or sentinel values if the gradient is not finite."""
        if not is_finite:
            # Log a sentinel instead of NaN/Inf statistics. The real statistics must
            # not be computed afterwards, or they would overwrite the sentinel.
            sentinel = self._NON_FINITE_SENTINEL
            add_metrics_to_stats(stats, "grad", name, grad_data, self.metrics, override=sentinel)
            stats[f"grad/{name}/mean"] = sentinel
            if self.log_quantiles and grad_data.numel() < self._MAX_QUANTILE_NUMEL:
                for quantile in self._QUANTILES:
                    stats[f"grad/{name}/quantile-{quantile}"] = sentinel
            return

        stats[f"grad/{name}/mean"] = grad_data.mean().item()
        # Full statistics only for gradients with more than one element. Checking
        # numel() instead of shape[0] also handles 0-d (scalar) parameters.
        if grad_data.dim() > 1 or grad_data.numel() > 1:
            add_metrics_to_stats(stats, "grad", name, grad_data, self.metrics)
            self._add_quantiles(stats, "grad", name, grad_data)

    def _add_quantiles(self, stats: dict[str, float], prefix: str, name: str, tensor: torch.Tensor) -> None:
        """Store quantiles of ``tensor`` under ``{prefix}/{name}/quantile-{q}`` if enabled."""
        if not self.log_quantiles or tensor.numel() >= self._MAX_QUANTILE_NUMEL:
            return
        # torch.quantile needs q on the input's device with the input's dtype; build
        # it per tensor because gathered gradients may live on a different device.
        q = torch.tensor(self._QUANTILES, dtype=torch.float32, device=tensor.device)
        values = torch.quantile(tensor.float(), q, interpolation="linear")
        for quantile, value in zip(self._QUANTILES, values.tolist()):
            stats[f"{prefix}/{name}/quantile-{quantile}"] = value

    def _add_optimizer_state_stats(
        self, stats: dict[str, float], optimizer: torch.optim.Optimizer, name: str, param: torch.Tensor
    ) -> None:
        """Add momentum statistics and the per-parameter ``lr`` stored in the optimizer state."""
        state = optimizer.state.get(param, {})
        if self.log_momentum:
            # e.g. torch.optim.Adam keeps "exp_avg"/"exp_avg_sq", torch.optim.SGD "momentum_buffer".
            if "exp_avg" in state:
                moment1 = state["exp_avg"]
            elif "momentum_buffer" in state:
                moment1 = state["momentum_buffer"]
            else:
                moment1 = None
            if moment1 is not None:
                add_metrics_to_stats(stats, "1st_order_momentum", name, moment1, self.metrics)
            if "exp_avg_sq" in state:
                add_metrics_to_stats(stats, "2nd_order_momentum", name, state["exp_avg_sq"], self.metrics)
        if self.log_lrs and "lr" in state:
            # float() accepts both a one-element tensor and a plain Python number.
            stats[f"param/{name}/lr"] = float(state["lr"])