atropos/environments/optimizer/FOB/pytorch_fob/engine/parameter_groups.py
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Optional
from torch import nn
from torch.nn import Module
from torch.nn.parameter import Parameter
from pytorch_fob.engine.utils import some, log_warn


@dataclass
class ParameterGroup():
    """A named set of parameters together with optional per-group optimizer settings."""
    named_parameters: dict[str, Parameter]
    lr_multiplier: Optional[float] = field(default=None)
    weight_decay_multiplier: Optional[float] = field(default=None)
    optimizer_kwargs: dict[str, Any] = field(default_factory=dict)

    def __and__(self, other) -> "ParameterGroup":
        """Intersection of two groups; settings from `other` take precedence where both are set."""
        assert isinstance(other, ParameterGroup)
        n1 = set(self.named_parameters.keys())
        n2 = set(other.named_parameters.keys())
        all_params = self.named_parameters | other.named_parameters
        n12 = n1 & n2
        new_params = {n: all_params[n] for n in n12}
        return ParameterGroup(
            named_parameters=new_params,
            lr_multiplier=some(other.lr_multiplier, default=self.lr_multiplier),
            weight_decay_multiplier=some(other.weight_decay_multiplier, default=self.weight_decay_multiplier),
            optimizer_kwargs=self.optimizer_kwargs | other.optimizer_kwargs
        )

    def __len__(self) -> int:
        return len(self.named_parameters)

    def __bool__(self) -> bool:
        return not self.empty()

    def empty(self) -> bool:
        return len(self.named_parameters) == 0

    def to_optimizer_dict(
            self,
            lr: Optional[float] = None,
            weight_decay: Optional[float] = None
    ) -> dict[str, list[Parameter] | Any]:
        """Convert this group into a dict suitable for a torch optimizer, applying the multipliers."""
        names = sorted(self.named_parameters)
        d = {
            "params": [self.named_parameters[n] for n in names],
            "names": names,
            **self.optimizer_kwargs
        }
        if lr is not None:
            d["lr"] = self.lr_multiplier * lr if self.lr_multiplier is not None else lr
        if weight_decay is not None:
            d["weight_decay"] = self.weight_decay_multiplier * weight_decay \
                if self.weight_decay_multiplier is not None else weight_decay
        return d
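

# Usage sketch (illustrative; `model` stands for any nn.Module): a group with
# lr_multiplier=0.1 scales the base learning rate passed to `to_optimizer_dict`,
# while an unset weight_decay_multiplier leaves the base weight decay untouched:
#
#     group = ParameterGroup(named_parameters=dict(model.named_parameters()), lr_multiplier=0.1)
#     group.to_optimizer_dict(lr=1e-3, weight_decay=1e-2)
#     # -> {"params": [...], "names": [...], "lr": 1e-4, "weight_decay": 1e-2}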


class GroupedModel(Module):
    """
    Wrapper around a nn.Module to allow specifying different optimizer settings for different parameters.
    To use this feature for your task, inherit from this class and override the `parameter_groups` method.
    Then simply wrap your model before passing it to the `__init__` method of the `TaskModel` superclass.
    """
    def __init__(self, model: Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def parameter_groups(self) -> list[ParameterGroup]:
        return wd_group_named_parameters(self.model)

    def grouped_parameters(
            self,
            lr: Optional[float] = None,
            weight_decay: Optional[float] = None
    ) -> list[dict[str, list[Parameter] | Any]]:
        return [pg.to_optimizer_dict(lr, weight_decay) for pg in self.parameter_groups()]
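

# Usage sketch, as described in the docstring above (the task model and the
# "backbone." prefix are hypothetical): override `parameter_groups` to train a
# backbone with a reduced learning rate while the remaining parameters keep the default.
#
#     class MyGroupedModel(GroupedModel):
#         def parameter_groups(self) -> list[ParameterGroup]:
#             backbone = {n: p for n, p in self.model.named_parameters() if n.startswith("backbone.")}
#             rest = {n: p for n, p in self.model.named_parameters() if not n.startswith("backbone.")}
#             return [
#                 ParameterGroup(named_parameters=backbone, lr_multiplier=0.1),
#                 ParameterGroup(named_parameters=rest),
#             ]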


def merge_parameter_splits(split1: list[ParameterGroup], split2: list[ParameterGroup]) -> list[ParameterGroup]:
    """
    Merge two lists of ParameterGroup objects into a single list.
    Assumes that both input lists partition the parameters.
    """
    groups = []
    for pg1 in split1:
        for pg2 in split2:
            pg12 = pg1 & pg2
            if not pg12.empty():
                groups.append(pg12)
    return groups


def group_named_parameters(
        model: Module,
        g1_conds: Iterable[Callable] = (lambda *_: True,),
        g2_conds: Iterable[Callable] = (lambda *_: True,),
        special_conds: Iterable[Callable] = tuple(),
        ignore_conds: Iterable[Callable] = tuple(),
        g1_kwargs: Optional[dict[str, Any]] = None,
        g2_kwargs: Optional[dict[str, Any]] = None,
        debug: bool = False
) -> list[ParameterGroup]:
    """
    Group named parameters based on specified conditions and return a list of ParameterGroup objects.

    Args:
        model (Module): The neural network model.
        g1_conds (Iterable[Callable]): Conditions for selecting parameters for group 1.
        g2_conds (Iterable[Callable]): Conditions for selecting parameters for group 2.
        special_conds (Iterable[Callable]): Conditions for selecting special parameters that should not be grouped.
        ignore_conds (Iterable[Callable]): Conditions for ignoring parameters (e.g. if they occur in submodules).
        g1_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for the constructor of group 1.
        g2_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for the constructor of group 2.

    Returns:
        list[ParameterGroup]: A list of ParameterGroup objects containing named parameters.
    """
    g1_kwargs = g1_kwargs if g1_kwargs is not None else {}
    g2_kwargs = g2_kwargs if g2_kwargs is not None else {}
    s1 = set()
    s2 = set()
    special = set()
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn  # full param name
            if not p.requires_grad or fpn not in param_dict:
                continue  # frozen weights
            elif any(c(m, p, fpn) for c in ignore_conds):
                continue
            elif any(c(m, p, fpn) for c in special_conds):
                special.add(fpn)
            elif any(c(m, p, fpn) for c in g1_conds):
                s1.add(fpn)
            elif any(c(m, p, fpn) for c in g2_conds):
                s2.add(fpn)
            elif debug:
                log_warn("group_named_parameters: Not using any rule for ", fpn, " in ", type(m))
    # parameters not matched by any condition default to group 1
    s1 |= (param_dict.keys() - s2 - special)
    # validate that we considered every parameter
    inter_params = s1 & s2
    union_params = s1 | s2
    assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both s1/s2 sets!"
    assert len(param_dict.keys() - special - union_params) == 0, \
        f"parameters {str(param_dict.keys() - special - union_params)} were not separated into either s1/s2 set!"
    if not s2:
        param_groups = [ParameterGroup(
            named_parameters={pn: param_dict[pn] for pn in sorted(union_params)}
        )]
    else:
        param_groups = [
            ParameterGroup(
                named_parameters={pn: param_dict[pn] for pn in sorted(s1)},
                **g1_kwargs
            ),
            ParameterGroup(
                named_parameters={pn: param_dict[pn] for pn in sorted(s2)},
                **g2_kwargs
            ),
        ]
    return param_groups


def wd_group_named_parameters(model: Module) -> list[ParameterGroup]:
    """Split parameters into a weight-decay group and a no-weight-decay group (biases, norm layers, embeddings)."""
    whitelist_weight_modules = (nn.Linear, nn.modules.conv._ConvNd)  # pylint: disable=protected-access # noqa
    blacklist_weight_modules = (nn.modules.batchnorm._NormBase,  # pylint: disable=protected-access # noqa
                                nn.GroupNorm, nn.LayerNorm,
                                nn.LocalResponseNorm,
                                nn.Embedding)
    ignore_modules = (nn.Sequential,)
    apply_decay_conds = [lambda m, _, pn: pn.endswith('weight') and isinstance(m, whitelist_weight_modules)]
    apply_no_decay_conds = [lambda m, _, pn: pn.endswith('bias') or isinstance(m, blacklist_weight_modules)]
    special_conds = [lambda m, p, pn: hasattr(p, '_optim')]
    ignore_conds = [lambda m, p, pn: isinstance(m, ignore_modules)]
    return group_named_parameters(
        model,
        g1_conds=apply_decay_conds,
        g2_conds=apply_no_decay_conds,
        special_conds=special_conds,
        ignore_conds=ignore_conds,
        g2_kwargs={'weight_decay_multiplier': 0.0}
    )
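

# Usage sketch (assumes a plain nn.Module `net`; AdamW is just one possible optimizer):
# the group dicts produced via `GroupedModel.grouped_parameters` can be passed directly
# to a torch optimizer, with weight decay disabled for biases, norm layers and embeddings
# through the `weight_decay_multiplier=0.0` of the second group.
#
#     import torch
#     wrapped = GroupedModel(net)
#     optimizer = torch.optim.AdamW(wrapped.grouped_parameters(lr=1e-3, weight_decay=1e-2),
#                                   lr=1e-3, weight_decay=1e-2)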


def resolve_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> list[dict[str, Any]]:
    """Split two optimizer dicts into three: params only in dict1, only in dict2, and in both (kwargs merged)."""
    p1, p2 = dict1["params"], dict2["params"]
    n1, n2 = set(dict1["names"]), set(dict2["names"])
    n_to_p1 = dict(zip(dict1["names"], dict1["params"]))
    n_to_p2 = dict(zip(dict2["names"], dict2["params"]))
    assert len(n1) == len(p1)
    assert len(n2) == len(p2)
    kwarg1 = {k: v for k, v in dict1.items() if k not in ["params", "names"]}
    kwarg2 = {k: v for k, v in dict2.items() if k not in ["params", "names"]}
    n1_and_n2 = n1 & n2
    n1_no_n2 = n1 - n2
    n2_no_n1 = n2 - n1
    assert n1_and_n2 | n1_no_n2 | n2_no_n1 == n1 | n2
    outdict1 = {"params": [n_to_p1[n] for n in sorted(n1_no_n2)],
                "names": sorted(n1_no_n2), **kwarg1}
    outdict2 = {"params": [n_to_p2[n] for n in sorted(n2_no_n1)],
                "names": sorted(n2_no_n1), **kwarg2}
    # kwarg2 takes precedence if an arg is present in both dicts:
    outdict12 = {"params": [{**n_to_p1, **n_to_p2}[n] for n in sorted(n1_and_n2)],
                 "names": sorted(n1_and_n2), **kwarg1, **kwarg2}
    return [outdict1, outdict2, outdict12]


def intersect_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> Optional[dict[str, Any]]:
    d = resolve_parameter_dicts(dict1, dict2)[2]
    return d if len(d["params"]) > 0 else None


def merge_parameter_dicts(dict1: dict[str, Any], dict2: dict[str, Any]) -> list[dict[str, Any]]:
    d = resolve_parameter_dicts(dict1, dict2)
    return list(filter(lambda x: len(x["params"]) > 0, d))
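

# Usage sketch (p_a, p_b, p_c stand for arbitrary Parameter objects): merging two
# optimizer dicts keeps the parameters unique to each dict in their own group and
# combines the kwargs for the overlap, with dict2 winning on conflicting keys.
#
#     d1 = {"params": [p_a, p_b], "names": ["a", "b"], "lr": 1e-3}
#     d2 = {"params": [p_b, p_c], "names": ["b", "c"], "weight_decay": 0.0}
#     merge_parameter_dicts(d1, d2)
#     # -> [{"params": [p_a], "names": ["a"], "lr": 1e-3},
#     #     {"params": [p_c], "names": ["c"], "weight_decay": 0.0},
#     #     {"params": [p_b], "names": ["b"], "lr": 1e-3, "weight_decay": 0.0}]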