# This code is derived from https://github.com/HazyResearch/state-spaces
"""Implementation of FFN block in the style of Transformers."""
from functools import partial
from torch import nn
from espnet2.asr.state_spaces.base import SequenceModule
from espnet2.asr.state_spaces.components import DropoutNd, LinearActivation


class FF(SequenceModule):
    """Feed-forward block: Linear -> activation -> dropout -> Linear.

    Position-wise FFN in the style of Transformers, expanding d_input
    to expand * d_input in the hidden layer.
    """

    def __init__(
        self,
        d_input,
        expand=2,
        d_output=None,
        transposed=False,
        activation="gelu",
        initializer=None,
        dropout=0.0,
        tie_dropout=False,
    ):
        super().__init__()
        self.d_output = d_input if d_output is None else d_output
        self.transposed = transposed
        d_inner = expand * d_input

        # Expansion layer with the activation fused in
        linear1 = LinearActivation(
            d_input,
            d_inner,
            transposed=transposed,
            activation=activation,
            initializer=initializer,
            activate=True,
        )
        # tie_dropout ties the dropout mask across the sequence dimension
        dropout_cls = (
            partial(DropoutNd, transposed=self.transposed)
            if tie_dropout
            else nn.Dropout
        )
        # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout
        drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity()

        # Projection back to d_output, with no activation
        linear2 = LinearActivation(
            d_inner,
            self.d_output,
            transposed=transposed,
            activation=None,
            initializer=initializer,
            activate=False,
        )

        self.ff = nn.Sequential(
            linear1,
            drop,
            linear2,
        )

    def forward(self, x, *args, **kwargs):
        # FF is memoryless, so there is no state to return
        return self.ff(x), None

    def step(self, x, state, **kwargs):
        # x: [batch, d_input]
        if self.transposed:
            # self.ff expects [batch, d_input, seq_len]; add and remove a
            # singleton length dimension around the single timestep
            return self.ff(x.unsqueeze(-1)).squeeze(-1), state
        else:
            return self.ff(x), state
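

# A minimal usage sketch (not part of the original module), assuming the
# default transposed=False layout of (batch, seq_len, d_input). The d_input,
# batch, and seq_len values below are illustrative.
if __name__ == "__main__":
    import torch

    ff = FF(d_input=64, expand=2, dropout=0.1)

    # Full-sequence use: forward returns (output, None) since FF keeps no state
    x = torch.randn(8, 100, 64)  # (batch, seq_len, d_input)
    y, _ = ff(x)
    assert y.shape == (8, 100, 64)  # d_output defaults to d_input

    # Stepwise (recurrent-style) use on one timestep of shape (batch, d_input);
    # FF is stateless, so the state passed in is returned unchanged.
    y_t, state = ff.step(torch.randn(8, 64), state=None)
    assert y_t.shape == (8, 64)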