atropos/environments/optimizer/FOB/pytorch_fob/optimizers/optimizers.py
2025-05-18 16:36:28 -07:00

27 lines
940 B
Python

import importlib
from pathlib import Path
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from pytorch_fob.engine.parameter_groups import GroupedModel
from pytorch_fob.engine.configs import OptimizerConfig
def import_optimizer(name: str):
    """Import and return the ``optimizer`` submodule of the optimizer package *name*.

    The module imported is ``pytorch_fob.optimizers.<name>.optimizer``.
    Import errors raised by ``importlib.import_module`` (for example
    ``ModuleNotFoundError`` when no such package exists) propagate unchanged.
    """
    qualified_name = f"pytorch_fob.optimizers.{name}.optimizer"
    module = importlib.import_module(qualified_name)
    return module
def optimizer_path(name: str) -> Path:
    """Return the absolute path of the optimizer directory called *name*.

    The path is built next to the resolved location of this file. It is not
    checked for existence.
    """
    this_file = Path(__file__).resolve()
    return this_file.parent.joinpath(name)
def optimizer_names() -> list[str]:
    """Return the names of the optimizer directories next to this file.

    Every sub-directory of this module's directory is reported, with these
    exceptions:

    * the helper directories ``__pycache__`` and ``lr_schedulers``;
    * hidden directories (names starting with ``.``, e.g.
      ``.ipynb_checkpoints``). A leading dot is not a valid Python package
      name, so such a directory could never be imported as an optimizer.

    The result is sorted. ``Path.iterdir`` yields entries in arbitrary,
    filesystem-dependent order, and sorting makes the output deterministic.
    """
    excluded = {"__pycache__", "lr_schedulers"}
    base_dir = Path(__file__).parent
    return sorted(
        entry.name
        for entry in base_dir.iterdir()
        if entry.is_dir()
        and entry.name not in excluded
        and not entry.name.startswith(".")
    )
class Optimizer:
    """Delegate optimizer construction to the module selected by the config.

    The concrete implementation is chosen by ``config.name``. That name is
    passed to ``import_optimizer``, and the returned module's
    ``configure_optimizers`` function does the actual work.
    """

    def __init__(self, config: OptimizerConfig) -> None:
        # Stored unchanged; read later by configure_optimizers.
        self.config = config

    def configure_optimizers(self, model: GroupedModel) -> OptimizerLRScheduler:
        """Build the optimizer (and scheduler) for *model*.

        The optimizer module named by ``self.config.name`` is looked up on
        every call. Its ``configure_optimizers(model, config)`` is invoked and
        its result returned as-is.
        """
        cfg = self.config
        backend = import_optimizer(cfg.name)
        return backend.configure_optimizers(model, cfg)