mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Convert FOB submodule to regular folder
This commit is contained in:
parent
94f046ad40
commit
94825011a0
74 changed files with 4563 additions and 0 deletions
|
|
@ -0,0 +1,27 @@
|
|||
import importlib
|
||||
from pathlib import Path
|
||||
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
||||
from pytorch_fob.engine.parameter_groups import GroupedModel
|
||||
from pytorch_fob.engine.configs import OptimizerConfig
|
||||
|
||||
|
||||
def import_optimizer(name: str):
|
||||
return importlib.import_module(f"pytorch_fob.optimizers.{name}.optimizer")
|
||||
|
||||
|
||||
def optimizer_path(name: str) -> Path:
|
||||
return Path(__file__).resolve().parent / name
|
||||
|
||||
|
||||
def optimizer_names() -> list[str]:
|
||||
EXCLUDE = ["__pycache__", "lr_schedulers"]
|
||||
return [d.name for d in Path(__file__).parent.iterdir() if d.is_dir() and d.name not in EXCLUDE]
|
||||
|
||||
|
||||
class Optimizer():
|
||||
def __init__(self, config: OptimizerConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
def configure_optimizers(self, model: GroupedModel) -> OptimizerLRScheduler:
|
||||
optimizer_module = import_optimizer(self.config.name)
|
||||
return optimizer_module.configure_optimizers(model, self.config)
|
||||
Loading…
Add table
Add a link
Reference in a new issue