# This code is derived from https://github.com/HazyResearch/state-spaces
"""Implements downsampling and upsampling on sequences."""
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn
from espnet2.asr.state_spaces.base import SequenceModule
from espnet2.asr.state_spaces.components import LinearActivation
"""Simple pooling functions that just downsample or repeat
stride: Subsample on the layer dimension
expand: Repeat on the feature dimension
"""


def downsample(x, stride=1, expand=1, transposed=False):
    if x is None:
        return None
    if stride > 1:
        assert x.ndim == 3, (
            "Downsampling with higher-dimensional inputs is currently not supported. "
            "It is recommended to use average or spectral pooling instead."
        )
        if transposed:
            x = x[..., 0::stride]
        else:
            x = x[..., 0::stride, :]

    if expand > 1:
        if transposed:
            x = repeat(x, "b d ... -> b (d e) ...", e=expand)
        else:
            x = repeat(x, "b ... d -> b ... (d e)", e=expand)

    return x
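

# Example (a sketch; shapes here are illustrative, not from the source): with
# batch-last input (B, L, D), `stride` subsamples the length axis and `expand`
# repeats the feature axis.
#
#     >>> x = torch.randn(2, 8, 4)
#     >>> downsample(x, stride=2, expand=3).shape
#     torch.Size([2, 4, 12])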


def upsample(x, stride=1, expand=1, transposed=False):
    if x is None:
        return None
    if expand > 1:
        if transposed:
            x = reduce(x, "... (d e) l -> ... d l", "mean", e=expand)
        else:
            x = reduce(x, "... (d e) -> ... d", "mean", e=expand)
    if stride > 1:
        if transposed:
            x = repeat(x, "... l -> ... (l e)", e=stride)
        else:
            x = repeat(x, "... l d -> ... (l e) d", e=stride)
    return x
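

# Example (sketch, illustrative shapes): upsample inverts the layout produced
# by downsample above; `expand` mean-pools the feature axis back down and
# `stride` repeats along the length axis.
#
#     >>> x = torch.randn(2, 4, 12)
#     >>> upsample(x, stride=2, expand=3).shape
#     torch.Size([2, 8, 4])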


class DownSample(SequenceModule):
    def __init__(self, d_input, stride=1, expand=1, transposed=True):
        super().__init__()
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.transposed = transposed

    def forward(self, x):
        return downsample(x, self.stride, self.expand, self.transposed)

    def step(self, x, state, **kwargs):
        if self.stride > 1 or self.expand > 1:
            raise NotImplementedError
        return x, state

    @property
    def d_output(self):
        return self.d_input * self.expand


class DownAvgPool(SequenceModule):
    def __init__(self, d_input, stride=1, expand=1, transposed=True):
        super().__init__()
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.transposed = transposed

    def forward(self, x):
        if not self.transposed:
            x = rearrange(x, "b ... d -> b d ...")

        if self.stride > 1:
            # einops appears slower than F
            if x.ndim == 3:
                x = F.avg_pool1d(x, self.stride, self.stride)
            elif x.ndim == 4:
                x = F.avg_pool2d(x, self.stride, self.stride)
            else:
                # Reduction string, e.g. "b d (l1 2) (l2 2) -> b d l1 l2"
                reduce_str = (
                    "b d "
                    + " ".join([f"(l{i} {self.stride})" for i in range(x.ndim - 2)])
                    + " -> b d "
                    + " ".join([f"l{i}" for i in range(x.ndim - 2)])
                )
                x = reduce(x, reduce_str, "mean")

        if self.expand > 1:
            x = repeat(x, "b d ... -> b (d e) ...", e=self.expand)

        if not self.transposed:
            x = rearrange(x, "b d ... -> b ... d")
        return x

    def step(self, x, state, **kwargs):
        if self.stride > 1 or self.expand > 1:
            raise NotImplementedError
        return x, state

    @property
    def d_output(self):
        return self.d_input * self.expand
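

# Example (sketch, illustrative shapes): in the transposed (B, D, L) layout,
# stride-2 average pooling halves the length axis via F.avg_pool1d.
#
#     >>> pool = DownAvgPool(d_input=4, stride=2, transposed=True)
#     >>> pool(torch.randn(2, 4, 8)).shape
#     torch.Size([2, 4, 4])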


class DownSpectralPool(SequenceModule):
    def __init__(self, d_input, stride=1, expand=1, transposed=True):
        super().__init__()
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.transposed = transposed

    def forward(self, x):
        """Forward pass.

        x: (B, L..., D)
        """
        if not self.transposed:
            x = rearrange(x, "b ... d -> b d ...")
        shape = x.shape[2:]
        x_f = torch.fft.ifftn(x, s=shape)

        # Keep only the lowest-frequency Fourier modes along each length axis,
        # then transform back: pooling without the aliasing of plain subsampling.
        for axis, length in enumerate(shape):
            assert length % self.stride == 0, "input length must be divisible by stride"
            new_length = length // self.stride
            idx = torch.cat(
                [
                    torch.arange(0, new_length - new_length // 2),
                    length + torch.arange(-new_length // 2, 0),
                ]
            ).to(x_f.device)
            x_f = torch.index_select(x_f, 2 + axis, idx)

        x = torch.fft.ifftn(x_f, s=[length // self.stride for length in shape])
        x = x.real

        if self.expand > 1:
            x = repeat(x, "b d ... -> b (d e) ...", e=self.expand)

        if not self.transposed:
            x = rearrange(x, "b d ... -> b ... d")
        return x

    def step(self, x, state, **kwargs):
        if self.stride > 1 or self.expand > 1:
            raise NotImplementedError
        return x, state

    @property
    def d_output(self):
        return self.d_input * self.expand
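

# Example (sketch, illustrative shapes): spectral pooling keeps only the
# lowest Fourier modes along each length axis, so stride 2 halves the length.
#
#     >>> pool = DownSpectralPool(d_input=4, stride=2, transposed=True)
#     >>> pool(torch.randn(2, 4, 8)).shape
#     torch.Size([2, 4, 4])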


class UpSample(nn.Module):
    def __init__(self, d_input, stride=1, expand=1, transposed=True):
        super().__init__()
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.transposed = transposed

    def forward(self, x):
        return upsample(x, self.stride, self.expand, self.transposed)

    @property
    def d_output(self):
        return self.d_input // self.expand

    def step(self, x, state, **kwargs):
        if self.stride > 1 or self.expand > 1:
            raise NotImplementedError
        return x, state
""" Pooling functions with trainable parameters """
# For the flexible backbone SequenceModel
[docs]class DownLinearPool(SequenceModule):
def __init__(self, d_input, stride=1, expand=1, transposed=True):
super().__init__()
self.d_input = d_input
self.stride = stride
self.expand = expand
self.transposed = transposed
self.linear = LinearActivation(
d_input * stride,
d_input * expand,
transposed=transposed,
)
[docs] def forward(self, x):
if self.transposed:
x = rearrange(x, "... h (l s) -> ... (h s) l", s=self.stride)
else:
x = rearrange(x, "... (l s) h -> ... l (h s)", s=self.stride)
x = self.linear(x)
return x
[docs] def step(self, x, state, **kwargs):
if self.stride > 1 or self.expand > 1:
raise NotImplementedError
return x, state
@property
def d_output(self):
return self.d_input * self.expand
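

# Example (sketch, illustrative shapes): a stride of 2 folds pairs of adjacent
# steps into the feature axis, and the linear layer maps d_input * stride
# features back to d_input * expand.
#
#     >>> pool = DownLinearPool(d_input=4, stride=2, expand=1, transposed=True)
#     >>> pool(torch.randn(2, 4, 8)).shape
#     torch.Size([2, 4, 4])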
""" Pooling functions with trainable parameters """


class DownPool2d(SequenceModule):
    def __init__(self, d_input, d_output, stride=1, transposed=True, weight_norm=True):
        super().__init__()
        self.transposed = transposed
        self.linear = LinearActivation(
            d_input,
            d_output,
            transposed=transposed,
            weight_norm=weight_norm,
        )
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)

    def forward(self, x):
        # Only the transposed (B, D, H, W) layout is pooled here; the
        # non-transposed path passes through unchanged.
        if self.transposed:
            x = self.pool(x)
        return x


# DownLinearPool is used by the registry (for isotropic backbone)
# DownPool is essentially the same as DownLinearPool. These should be consolidated
class DownPool(SequenceModule):
    def __init__(
        self,
        d_input,
        d_output=None,
        expand=None,
        stride=1,
        transposed=True,
        weight_norm=True,
        initializer=None,
        activation=None,
    ):
        super().__init__()
        # Exactly one of d_output or expand must be specified
        assert (d_output is None) + (expand is None) == 1
        if d_output is None:
            d_output = d_input * expand

        self.d_output = d_output
        self.stride = stride
        self.transposed = transposed

        self.linear = LinearActivation(
            d_input * stride,
            d_output,
            transposed=transposed,
            initializer=initializer,
            weight_norm=weight_norm,
            activation=activation,
            activate=activation is not None,
        )

    def forward(self, x):
        if self.transposed:
            x = rearrange(x, "... h (l s) -> ... (h s) l", s=self.stride)
        else:
            x = rearrange(x, "... (l s) h -> ... l (h s)", s=self.stride)
        x = self.linear(x)
        return x, None

    def step(self, x, state, **kwargs):
        """Step one time step as a recurrent model.

        x: (..., H)
        """
        if x is None:
            return None, state
        state.append(x)
        if len(state) == self.stride:
            # A full window of `stride` frames has accumulated: pool it
            x = rearrange(torch.stack(state, dim=-1), "... h s -> ... (h s)")
            if self.transposed:
                x = x.unsqueeze(-1)
            x = self.linear(x)
            if self.transposed:
                x = x.squeeze(-1)
            return x, []
        else:
            return None, state

    def default_state(self, *batch_shape, device=None):
        return []
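

# Stepping sketch (illustrative shapes): step() buffers one frame per call and
# emits a pooled output every `stride` calls, matching forward() frame by frame.
#
#     >>> pool = DownPool(d_input=4, expand=1, stride=2, transposed=False)
#     >>> state = pool.default_state()
#     >>> y, state = pool.step(torch.randn(3, 4), state)  # y is None (buffering)
#     >>> y, state = pool.step(torch.randn(3, 4), state)  # y has shape (3, 4)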


class UpPool(SequenceModule):
    def __init__(
        self,
        d_input,
        d_output,
        stride,
        transposed=True,
        weight_norm=True,
        initializer=None,
        activation=None,
    ):
        super().__init__()
        self.d_input = d_input
        self._d_output = d_output
        self.stride = stride
        self.transposed = transposed

        self.linear = LinearActivation(
            d_input,
            d_output * stride,
            transposed=transposed,
            initializer=initializer,
            weight_norm=weight_norm,
            activation=activation,
            activate=activation is not None,
        )

    def forward(self, x, skip=None):
        x = self.linear(x)
        if self.transposed:
            x = F.pad(x[..., :-1], (1, 0))  # Shift to ensure causality
            x = rearrange(x, "... (h s) l -> ... h (l s)", s=self.stride)
        else:
            x = F.pad(x[..., :-1, :], (0, 0, 1, 0))  # Shift to ensure causality
            x = rearrange(x, "... l (h s) -> ... (l s) h", s=self.stride)
        if skip is not None:
            x = x + skip
        return x, None

    def step(self, x, state, **kwargs):
        """Step one time step as a recurrent model.

        x: (..., H)
        """
        assert len(state) > 0
        y, state = state[0], state[1:]
        if len(state) == 0:
            # Queue is empty: consume the next input and refill the queue
            # with the `stride` output frames it expands into
            assert x is not None
            if self.transposed:
                x = x.unsqueeze(-1)
            x = self.linear(x)
            if self.transposed:
                x = x.squeeze(-1)
            x = rearrange(x, "... (h s) -> ... h s", s=self.stride)
            state = list(torch.unbind(x, dim=-1))
        else:
            assert x is None
        return y, state

    def default_state(self, *batch_shape, device=None):
        state = torch.zeros(
            batch_shape + (self.d_output, self.stride), device=device
        )  # (batch, h, s)
        state = list(torch.unbind(state, dim=-1))  # List of (..., H)
        return state

    @property
    def d_output(self):
        return self._d_output
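

# Stepping sketch (illustrative shapes): default_state() primes the queue with
# `stride` zero frames, realizing the causal shift from forward(); a new input
# is consumed only on the call where the queue runs empty.
#
#     >>> up = UpPool(d_input=8, d_output=4, stride=2, transposed=False)
#     >>> state = up.default_state(3)                    # queue of 2 zero frames
#     >>> y, state = up.step(None, state)                # y: (3, 4), queue not yet empty
#     >>> y, state = up.step(torch.randn(3, 8), state)   # consume input, refill queue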


registry = {
    "sample": DownSample,
    "pool": DownAvgPool,
    "linear": DownLinearPool,
    "spectral": DownSpectralPool,
}
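

# Registry sketch: downstream configs can select a pooling layer by name, e.g.
#
#     >>> pool_cls = registry["pool"]        # -> DownAvgPool
#     >>> layer = pool_cls(d_input=4, stride=2)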