"""MLP with convolutional gating (cgMLP) definition.
References:
https://openreview.net/forum?id=RA-zVvZLYIy
https://arxiv.org/abs/2105.08050
"""

import torch

from espnet.nets.pytorch_backend.nets_utils import get_activation
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm


class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU)."""

    def __init__(
        self,
        size: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        # Half of the input channels form the gate; the other half are gated.
        n_channels = size // 2  # split input channels
        self.norm = LayerNorm(n_channels)
        # Depthwise convolution over the time axis (groups == channels),
        # padded so that the sequence length is preserved.
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=n_channels,
        )
        if use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None

        if gate_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = get_activation(gate_activation)

        self.dropout = torch.nn.Dropout(dropout_rate)
    def espnet_initialization_fn(self):
        # Near-identity initialization: weights start close to zero and
        # biases at one, so the gating signal is nearly constant early in
        # training.
        torch.nn.init.normal_(self.conv.weight, std=1e-6)
        torch.nn.init.ones_(self.conv.bias)
        if self.linear is not None:
            torch.nn.init.normal_(self.linear.weight, std=1e-6)
            torch.nn.init.ones_(self.linear.bias)
    def forward(self, x, gate_add=None):
        """Forward method.

        Args:
            x (torch.Tensor): (N, T, D)
            gate_add (torch.Tensor, optional): (N, T, D/2)

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """
        x_r, x_g = x.chunk(2, dim=-1)

        x_g = self.norm(x_g)  # (N, T, D/2)
        # Conv1d expects (N, C, T), hence the transposes around it.
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        if gate_add is not None:
            x_g = x_g + gate_add

        x_g = self.act(x_g)
        out = x_r * x_g  # (N, T, D/2)
        out = self.dropout(out)
        return out
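

# A minimal usage sketch of the CSGU (illustrative, not part of the original
# module; the sizes are arbitrary): with size=512 the input is split into two
# (N, T, 256) halves, the gate half is normalized and convolved over time,
# and the two halves are multiplied:
#
#     csgu = ConvolutionalSpatialGatingUnit(
#         size=512, kernel_size=31, dropout_rate=0.1,
#         use_linear_after_conv=False, gate_activation="identity",
#     )
#     out = csgu(torch.randn(4, 100, 512))  # -> (4, 100, 256)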


class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP)."""

    def __init__(
        self,
        size: int,
        linear_units: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(size, linear_units), torch.nn.GELU()
        )
        self.csgu = ConvolutionalSpatialGatingUnit(
            size=linear_units,
            kernel_size=kernel_size,
            dropout_rate=dropout_rate,
            use_linear_after_conv=use_linear_after_conv,
            gate_activation=gate_activation,
        )
        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)
    def forward(self, x, mask):
        """Forward method.

        Args:
            x (torch.Tensor or tuple): input of shape (N, T, D), optionally
                paired with a positional embedding as (input, pos_emb)
            mask: not used in this module

        Returns:
            out: (N, T, D), or ((N, T, D), pos_emb) if pos_emb was given
        """
        if isinstance(x, tuple):
            xs_pad, pos_emb = x
        else:
            xs_pad, pos_emb = x, None

        xs_pad = self.channel_proj1(xs_pad)  # size -> linear_units
        xs_pad = self.csgu(xs_pad)  # linear_units -> linear_units/2
        xs_pad = self.channel_proj2(xs_pad)  # linear_units/2 -> size

        if pos_emb is not None:
            out = (xs_pad, pos_emb)
        else:
            out = xs_pad

        return out
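

if __name__ == "__main__":
    # Minimal smoke-test sketch (added for illustration; the chosen sizes are
    # arbitrary, not from the original module): run a cgMLP block on random
    # features and check that the output keeps the input shape, with and
    # without a positional-embedding tuple.
    block = ConvolutionalGatingMLP(
        size=256,
        linear_units=1024,
        kernel_size=31,
        dropout_rate=0.1,
        use_linear_after_conv=False,
        gate_activation="identity",
    )
    xs = torch.randn(2, 50, 256)  # (N, T, D)
    assert block(xs, mask=None).shape == (2, 50, 256)

    pos_emb = torch.randn(2, 50, 256)
    ys, pos_out = block((xs, pos_emb), mask=None)
    assert ys.shape == (2, 50, 256) and pos_out is pos_emb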