mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Convert FOB submodule to regular folder
This commit is contained in:
parent
94f046ad40
commit
94825011a0
74 changed files with 4563 additions and 0 deletions
119
environments/optimizer/FOB/pytorch_fob/tasks/tasks.py
Normal file
119
environments/optimizer/FOB/pytorch_fob/tasks/tasks.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import importlib
|
||||
import time
|
||||
from typing import Any, Callable, Optional
|
||||
from pathlib import Path
|
||||
from lightning import LightningModule, LightningDataModule
|
||||
from lightning.pytorch.core.optimizer import LightningOptimizer
|
||||
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from pytorch_fob.optimizers import Optimizer
|
||||
from pytorch_fob.engine.configs import TaskConfig
|
||||
from pytorch_fob.engine.parameter_groups import GroupedModel
|
||||
|
||||
|
||||
def import_task(name: str):
|
||||
return importlib.import_module(f"pytorch_fob.tasks.{name}.task")
|
||||
|
||||
|
||||
def task_path(name: str) -> Path:
|
||||
return Path(__file__).resolve().parent / name
|
||||
|
||||
|
||||
def task_names() -> list[str]:
|
||||
EXCLUDE = ["__pycache__"]
|
||||
return [d.name for d in Path(__file__).parent.iterdir() if d.is_dir() and d.name not in EXCLUDE]
|
||||
|
||||
|
||||
class TaskModel(LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module | GroupedModel,
|
||||
optimizer: Optimizer,
|
||||
config: TaskConfig,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.optimizer = optimizer
|
||||
self.model = model if isinstance(model, GroupedModel) else GroupedModel(model)
|
||||
self.optimizer_times_ms = []
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model.forward(*args, **kwargs)
|
||||
|
||||
def configure_optimizers(self) -> OptimizerLRScheduler:
|
||||
return self.optimizer.configure_optimizers(self.model)
|
||||
|
||||
def optimizer_step(
|
||||
self,
|
||||
epoch: int,
|
||||
batch_idx: int,
|
||||
optimizer: torch.optim.Optimizer | LightningOptimizer,
|
||||
optimizer_closure: Optional[Callable[[], Any]] = None,
|
||||
) -> None:
|
||||
start = time.time_ns()
|
||||
optimizer.step(closure=optimizer_closure) # type: ignore
|
||||
end = time.time_ns()
|
||||
duration_ms = (end - start) / 1e6
|
||||
self.optimizer_times_ms.append(duration_ms)
|
||||
|
||||
|
||||
class TaskDataModule(LightningDataModule):
|
||||
def __init__(self, config: TaskConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.workers: int = min(config.workers, 16)
|
||||
self.data_dir: Path = config.data_dir / config.name
|
||||
self.batch_size: int = config.batch_size
|
||||
self.data_train: Any
|
||||
self.data_val: Any
|
||||
self.data_test: Any
|
||||
self.data_predict: Any
|
||||
self.collate_fn = None
|
||||
|
||||
def check_dataset(self, data):
|
||||
"""Make sure that all tasks have correctly configured their data sets"""
|
||||
if not data:
|
||||
raise NotImplementedError("Each task has its own data set")
|
||||
if not self.batch_size or self.batch_size < 1:
|
||||
raise NotImplementedError("Each task configures its own batch_size. \
|
||||
Please set it explicitely, to avoid confusion.")
|
||||
|
||||
def train_dataloader(self):
|
||||
self.check_dataset(self.data_train)
|
||||
return DataLoader(
|
||||
self.data_train,
|
||||
shuffle=True,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.workers,
|
||||
collate_fn=self.collate_fn
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
self.check_dataset(self.data_val)
|
||||
return DataLoader(
|
||||
self.data_val,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.workers,
|
||||
collate_fn=self.collate_fn
|
||||
)
|
||||
|
||||
def test_dataloader(self):
|
||||
self.check_dataset(self.data_test)
|
||||
return DataLoader(
|
||||
self.data_test,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.workers,
|
||||
collate_fn=self.collate_fn
|
||||
)
|
||||
|
||||
def predict_dataloader(self):
|
||||
self.check_dataset(self.data_predict)
|
||||
return DataLoader(
|
||||
self.data_predict,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.workers,
|
||||
collate_fn=self.collate_fn
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue