mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Replace isort with ruff for import sorting
- Update pre-commit config to use ruff with --select=I for imports only - Apply ruff import sorting to fix pre-commit issues - Ruff and black work together without conflicts
This commit is contained in:
parent
55cdb83cbf
commit
61fdc37f61
13 changed files with 21 additions and 11 deletions
|
|
@ -9,6 +9,7 @@ import torch
|
|||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from neps.utils.common import get_initial_directory, load_lightning_checkpoint
|
||||
|
||||
from pytorch_fob.engine.engine import Engine, Run
|
||||
|
||||
#############################################################
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
from pprint import pprint
|
||||
|
||||
import yaml
|
||||
|
||||
from pytorch_fob.engine.engine import Engine
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@ import deepspeed
|
|||
import torch
|
||||
from lightning import Callback, LightningModule, Trainer
|
||||
from lightning_utilities.core.rank_zero import rank_zero_only
|
||||
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, seconds_to_str
|
||||
from torch.linalg import vector_norm
|
||||
|
||||
from pytorch_fob.engine.utils import log_debug, log_info, log_warn, seconds_to_str
|
||||
|
||||
|
||||
class RestrictTrainEpochs(Callback):
|
||||
"""Counts number of epochs since start of training and stops if max_epochs is reached."""
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Any, Callable, Iterable, Iterator, Literal, Optional
|
|||
|
||||
from matplotlib.figure import Figure
|
||||
from pandas import DataFrame, concat, json_normalize
|
||||
|
||||
from pytorch_fob.engine.configs import EvalConfig
|
||||
from pytorch_fob.engine.grid_search import grid_search
|
||||
from pytorch_fob.engine.parser import YAMLParser
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
|
||||
from pytorch_fob.engine.utils import log_warn, some
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from pytorch_fob.engine.utils import log_warn, some
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterGroup:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from lightning.pytorch.callbacks import (
|
|||
)
|
||||
from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger
|
||||
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
|
||||
|
||||
from pytorch_fob.engine.callbacks import (
|
||||
LogTrainingStats,
|
||||
OptimizerTime,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from tempfile import TemporaryDirectory
|
|||
from typing import Any, Iterable, Optional, Sequence
|
||||
|
||||
import yaml
|
||||
|
||||
from pytorch_fob.engine.run import Run
|
||||
from pytorch_fob.engine.slurm import Slurm
|
||||
from pytorch_fob.engine.utils import (
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import matplotlib.pyplot as plt
|
|||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pytorch_fob.engine.parser import YAMLParser
|
||||
from pytorch_fob.engine.utils import (
|
||||
AttributeDict,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import importlib
|
|||
from pathlib import Path
|
||||
|
||||
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
||||
|
||||
from pytorch_fob.engine.configs import OptimizerConfig
|
||||
from pytorch_fob.engine.parameter_groups import GroupedModel
|
||||
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@ import torch
|
|||
from lightning import LightningDataModule, LightningModule
|
||||
from lightning.pytorch.core.optimizer import LightningOptimizer
|
||||
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_fob.engine.configs import TaskConfig
|
||||
from pytorch_fob.engine.parameter_groups import GroupedModel
|
||||
from pytorch_fob.optimizers import Optimizer
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def import_task(name: str):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue