# This code is derived from https://github.com/HazyResearch/state-spaces
"""Implementations of different types of residual functions."""
import torch
from torch import nn
class Residual(nn.Module):
    """Residual connection with constant affine weights.

    The forward pass returns ``alpha * x + beta * y``. Depending on the two
    constants this can simulate a standard residual (``alpha = beta = 1``),
    no residual (``alpha = 0``), and "constant gates".

    Args:
        i_layer: Layer index. Only stored here; subclasses use it to make
            their weights depend on depth.
        d_input: Size associated with the skip input ``x`` (presumably its
            feature dimension -- confirm against the caller).
        d_model: Size associated with the branch output ``y``; exposed
            unchanged as ``d_output``.
        alpha: Constant multiplier applied to ``x``.
        beta: Constant multiplier applied to ``y``.

    Raises:
        AssertionError: If ``d_input != d_model`` while ``alpha`` is not 0.
    """

    def __init__(self, i_layer, d_input, d_model, alpha=1.0, beta=1.0):
        super().__init__()
        # x and y are added together, so the two sizes have to agree unless
        # the x term is switched off entirely (alpha == 0).
        assert (d_input == d_model) or alpha == 0.0, (
            f"d_input ({d_input}) must equal d_model ({d_model}) "
            "unless alpha == 0.0"
        )
        self.i_layer = i_layer
        self.d_input = d_input
        self.d_model = d_model
        self.alpha = alpha
        self.beta = beta

    @property
    def d_output(self):
        """Size reported for the output of this connection (``d_model``)."""
        return self.d_model

    def forward(self, x, y, transposed=False):
        """Combine the skip input and the branch output.

        Args:
            x: Skip-path input.
            y: Main-branch output.
            transposed: Not used by this class; accepted so that every
                residual class can be called with the same arguments.

        Returns:
            ``alpha * x + beta * y``, or just ``beta * y`` when ``alpha`` is
            falsy (in which case ``x`` is never touched).
        """
        # Skip the multiplication for the common beta == 1 case.
        if self.beta != 1.0:
            y = self.beta * y
        return self.alpha * x + y if self.alpha else y
class Affine(Residual):
    """Residual connection with learnable multipliers on the main branch.

    The forward pass returns ``alpha * x + affine * y`` where ``affine`` is a
    trainable ``nn.Parameter``.

    Args:
        *args: Positional arguments forwarded to ``Residual.__init__``
            (``i_layer, d_input, d_model``).
        scalar: If true, learn a single multiplier; otherwise learn one
            multiplier per entry (``d_input`` of them).
        gamma: Depth exponent. Every multiplier is initialised to
            ``beta * i_layer ** (-gamma)``.
        **kwargs: Keyword arguments forwarded to ``Residual.__init__``
            (e.g. ``alpha``, ``beta``).

    Note:
        With ``i_layer == 0`` and ``gamma > 0`` the initial value is a
        division by zero and construction fails.
    """

    def __init__(self, *args, scalar=True, gamma=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.scalar = scalar
        self.gamma = gamma

        init_value = self.beta * self.i_layer ** (-self.gamma)
        n_weights = 1 if self.scalar else self.d_input
        self.affine = nn.Parameter(init_value * torch.ones(n_weights))

    def forward(self, x, y, transposed=False):
        """Return ``alpha * x + affine * y``.

        Args:
            x: Skip-path input.
            y: Main-branch output.
            transposed: If true, a trailing axis is appended to the
                multiplier so it broadcasts against the second-to-last axis
                of ``y`` instead of the last one.
        """
        weight = self.affine
        if transposed:
            weight = weight.unsqueeze(-1)
        return self.alpha * x + weight * y
class Feedforward(Residual):
    """No skip connection: a ``Residual`` fixed to ``alpha=0.0, beta=1.0``.

    Args:
        *args: Positional arguments forwarded to ``Residual.__init__``
            (``i_layer, d_input, d_model``).
    """

    def __init__(self, *args):
        super().__init__(*args, alpha=0.0, beta=1.0)
class Highway(Residual):
    """Gated residual connection with a learned sigmoid gate.

    With ``y' = Wy(y)`` (or ``Wy * y`` when ``elemwise``), the forward pass
    computes ``r = sigmoid(Wx(x) + y')`` and returns
    ``scaling_correction * (1 - r) * x + r * y'``.

    Args:
        *args: Positional arguments forwarded to ``Residual.__init__``
            (``i_layer, d_input, d_model``).
        scaling_correction: If true, the ``x`` term is multiplied by 1.732
            (approximately sqrt(3)); otherwise by 1.0.
        elemwise: If true, ``Wy`` is a vector of ``d_input`` weights drawn
            from a standard normal and multiplied elementwise into ``y``;
            otherwise ``Wy`` is a full ``nn.Linear(d_input, d_input)``.
    """

    def __init__(self, *args, scaling_correction=False, elemwise=False):
        super().__init__(*args)
        self.scaling_correction = 1.732 if scaling_correction else 1.0
        self.elemwise = elemwise
        self.Wx = nn.Linear(self.d_input, self.d_input)
        if self.elemwise:
            self.Wy = nn.Parameter(torch.randn(self.d_input))
        else:
            self.Wy = nn.Linear(self.d_input, self.d_input)

    def forward(self, x, y, transposed=False):
        """Mix ``x`` and the transformed ``y`` through a sigmoid gate.

        Args:
            x: Skip-path input.
            y: Main-branch output.
            transposed: Ignored. ``nn.Linear`` and the elementwise weight
                both act on the last axis, so the inputs presumably need
                their feature axis last -- confirm against the caller.
        """
        # y is replaced by its transform, which then feeds BOTH the gate
        # and the gated sum below.
        if self.elemwise:
            y = self.Wy * y
        else:
            y = self.Wy(y)
        gate = torch.sigmoid(self.Wx(x) + y)
        return self.scaling_correction * (1.0 - gate) * x + gate * y
class DecayResidual(Residual):
    """Residual connection whose mixing weights depend on the layer index.

    The branch weight is ``beta = i_layer ** (-power)``. The skip weight is
    ``sqrt(1 - beta ** 2)`` when ``l2`` is true (so the squared weights sum
    to one) and ``1 - beta`` otherwise (so the weights sum to one).

    Args:
        *args: Positional arguments forwarded to ``Residual.__init__``
            (``i_layer, d_input, d_model``).
        power: Exponent controlling how fast the branch weight shrinks as
            ``i_layer`` grows.
        l2: Selects which of the two skip-weight formulas above is used.

    Note:
        ``i_layer == 0`` with a positive ``power`` raises
        ``ZeroDivisionError`` in ``forward``; layer indices are presumably
        expected to start at 1 -- confirm against the caller.
    """

    def __init__(self, *args, power=0.5, l2=True):
        super().__init__(*args)
        self.power = power
        self.l2 = l2

    def forward(self, x, y, transposed=False):
        """Return ``alpha * x + beta * y`` with depth-dependent weights.

        Args:
            x: Skip-path input.
            y: Main-branch output.
            transposed: Not used by this class; accepted so that every
                residual class can be called with the same arguments.
        """
        beta = self.i_layer ** (-self.power)
        if self.l2:
            alpha = (1.0 - beta**2) ** 0.5
        else:
            alpha = 1.0 - beta
        return alpha * x + beta * y
# Maps a configuration string to the residual class it selects. Every class
# can be named either by a one-letter code or by a spelled-out alias.
registry = dict(
    # One-letter codes.
    F=Feedforward,
    N=Feedforward,
    R=Residual,
    H=Highway,
    D=DecayResidual,
    A=Affine,
    # Spelled-out aliases.
    none=Feedforward,
    ff=Feedforward,
    feedforward=Feedforward,
    residual=Residual,
    highway=Highway,
    decay=DecayResidual,
    affine=Affine,
)