import importlib
import time
from pathlib import Path
from typing import Any, Callable, Optional

import torch
from lightning import LightningDataModule, LightningModule
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from torch import nn
from torch.utils.data import DataLoader

from pytorch_fob.engine.configs import TaskConfig
from pytorch_fob.engine.parameter_groups import GroupedModel
from pytorch_fob.optimizers import Optimizer


def import_task(name: str):
    """Import the task module for the given task name."""
    return importlib.import_module(f"pytorch_fob.tasks.{name}.task")


def task_path(name: str) -> Path:
    """Return the directory of the given task, located next to this file."""
    return Path(__file__).resolve().parent / name


def task_names() -> list[str]:
    """List all available task names: every subdirectory except caches."""
    EXCLUDE = ["__pycache__"]
    return [d.name for d in Path(__file__).parent.iterdir() if d.is_dir() and d.name not in EXCLUDE]
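

# A minimal usage sketch (illustrative, not part of the original module):
# discover every task directory next to this file and import its task module.
def _example_task_discovery() -> None:
    for name in task_names():
        module = import_task(name)  # imports pytorch_fob.tasks.<name>.task
        print(name, task_path(name), module.__name__)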


class TaskModel(LightningModule):
    def __init__(
        self,
        model: nn.Module | GroupedModel,
        optimizer: Optimizer,
        config: TaskConfig,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.config = config
        self.optimizer = optimizer
        # Wrap plain modules so parameter groups are handled uniformly.
        self.model = model if isinstance(model, GroupedModel) else GroupedModel(model)
        self.optimizer_times_ms: list[float] = []

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        return self.optimizer.configure_optimizers(self.model)

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: torch.optim.Optimizer | LightningOptimizer,
        optimizer_closure: Optional[Callable[[], Any]] = None,
    ) -> None:
        # Time every optimizer step in milliseconds for later reporting.
        start = time.time_ns()
        optimizer.step(closure=optimizer_closure)  # type: ignore
        end = time.time_ns()
        duration_ms = (end - start) / 1e6
        self.optimizer_times_ms.append(duration_ms)
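

# Illustrative sketch (an assumption, not code from this repository): a concrete
# task would subclass TaskModel and supply its own steps. The loss and metric
# below are placeholders for whatever the task actually uses.
class _ExampleClassificationModel(TaskModel):
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        accuracy = (self(x).argmax(dim=-1) == y).float().mean()
        self.log("val_accuracy", accuracy)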


class TaskDataModule(LightningDataModule):
    def __init__(self, config: TaskConfig) -> None:
        super().__init__()
        self.config = config
        # Cap worker count to keep data loading from oversubscribing the host.
        self.workers: int = min(config.workers, 16)
        self.data_dir: Path = config.data_dir / config.name
        self.batch_size: int = config.batch_size
        # Initialized to None; each task's `setup` must assign real datasets,
        # otherwise check_dataset raises a descriptive error instead of an
        # AttributeError from the unset attribute.
        self.data_train: Any = None
        self.data_val: Any = None
        self.data_test: Any = None
        self.data_predict: Any = None
        self.collate_fn = None

    def check_dataset(self, data):
        """Make sure that all tasks have correctly configured their data sets."""
        if not data:
            raise NotImplementedError("Each task has its own data set.")
        if not self.batch_size or self.batch_size < 1:
            raise NotImplementedError(
                "Each task configures its own batch_size. Please set it explicitly to avoid confusion."
            )

    def train_dataloader(self):
        self.check_dataset(self.data_train)
        return DataLoader(
            self.data_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        self.check_dataset(self.data_val)
        return DataLoader(
            self.data_val,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        self.check_dataset(self.data_test)
        return DataLoader(
            self.data_test,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.collate_fn,
        )

    def predict_dataloader(self):
        self.check_dataset(self.data_predict)
        return DataLoader(
            self.data_predict,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.collate_fn,
        )
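

# Illustrative sketch (an assumption, not code from this repository): a task
# fills in the data_* attributes during `setup`. A TensorDataset over random
# tensors stands in for a real dataset here.
class _ExampleDataModule(TaskDataModule):
    def setup(self, stage: Optional[str] = None) -> None:
        from torch.utils.data import TensorDataset  # local import: illustration only

        def make_split() -> TensorDataset:
            x = torch.randn(64, 8)
            y = torch.randint(0, 2, (64,))
            return TensorDataset(x, y)

        self.data_train = make_split()
        self.data_val = make_split()
        self.data_test = make_split()
        self.data_predict = make_split()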
|