# This code is modified from: https://github.com/HazyResearch/state-spaces/blob/main/src/utils/optim_groups.py  # noqa: E501
import torch.nn as nn
def add_optimizer_hooks(
    model,
    bias_weight_decay=False,
    normalization_weight_decay=False,
):
"""Set zero weight decay for some params
Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with
attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False,
for normalization parameters if normalization_weight_decay==False
See: https://discuss.pytorch.org/t/weight-decay-only-for-weights-of-nn-linear-and-nn-conv/114348 # noqa
"""
    # Separate out all parameters into those that will and won't experience
    # regularizing weight decay.
    blacklist_weight_modules = (nn.Embedding,)
    if not normalization_weight_decay:
        blacklist_weight_modules += (
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            # Not compatible with PyTorch 1.8.1:
            # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
            nn.GroupNorm,
            nn.SyncBatchNorm,
            nn.InstanceNorm1d,
            nn.InstanceNorm2d,
            nn.InstanceNorm3d,
            nn.LayerNorm,
            nn.LocalResponseNorm,
        )
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            if (
                (not bias_weight_decay and pn.endswith("bias"))
                or getattr(p, "_no_weight_decay", False)
                or isinstance(m, blacklist_weight_modules)
            ):
                setattr(p, "_optim", {"weight_decay": 0.0})