# espnet2/asr_transducer/encoder/modules/convolution.py
"""Convolution modules for X-former blocks."""

from typing import Dict, Optional, Tuple

import torch


class ConformerConvolution(torch.nn.Module):
    """ConformerConvolution module definition.

    Args:
        channels: The number of channels.
        kernel_size: Size of the convolving kernel.
        activation: Activation function.
        norm_args: Normalization module arguments.
        causal: Whether to use causal convolution (set to True if streaming).

    """

def __init__(
self,
channels: int,
kernel_size: int,
activation: torch.nn.Module = torch.nn.ReLU(),
norm_args: Dict = {},
causal: bool = False,
) -> None:
"""Construct an ConformerConvolution object."""
super().__init__()

        assert (kernel_size - 1) % 2 == 0, "kernel_size must be odd."

self.kernel_size = kernel_size
self.pointwise_conv1 = torch.nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
)
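
        # Causal mode keeps (kernel_size - 1) frames of left context and uses
        # no symmetric padding, so the convolution never sees future frames.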
if causal:
self.lorder = kernel_size - 1
padding = 0
else:
self.lorder = 0
padding = (kernel_size - 1) // 2
self.depthwise_conv = torch.nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
)
self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
self.pointwise_conv2 = torch.nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
)
self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x: ConformerConvolution input sequences. (B, T, D_hidden)
            mask: Source mask. (B, T_2)
            cache: ConformerConvolution input cache. (1, D_hidden, conv_kernel - 1)

        Returns:
            x: ConformerConvolution output sequences. (B, T, D_hidden)
            cache: ConformerConvolution output cache. (1, D_hidden, conv_kernel - 1)

        """
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)

        if mask is not None:
            x.masked_fill_(mask.unsqueeze(1).expand_as(x), 0.0)

        if self.lorder > 0:
            if cache is None:
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)
                cache = x[..., -self.lorder :]

        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x).transpose(1, 2)

        return x, cache
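
# A minimal usage sketch for ConformerConvolution (illustrative only; the
# shapes and hyperparameters below are assumptions, not part of the module).
# Non-causal mode preserves the time dimension; causal mode threads a cache
# of shape (1, channels, kernel_size - 1) across streaming chunks.
def _example_conformer_convolution() -> None:
    """Hypothetical demo helper; not part of the original ESPnet API."""
    conv_mod = ConformerConvolution(channels=256, kernel_size=31)
    x = torch.randn(2, 16, 256)  # (B, T, D_hidden)

    out, _ = conv_mod(x)
    assert out.shape == (2, 16, 256)

    # Streaming sketch: the cache is pre-allocated by the caller (the module
    # only updates a cache that is passed in), with a batch size of 1.
    streaming = ConformerConvolution(channels=256, kernel_size=31, causal=True)
    cache = torch.zeros(1, 256, streaming.kernel_size - 1)
    for chunk in torch.randn(1, 16, 256).split(8, dim=1):
        out, cache = streaming(chunk, cache=cache)  # out: (1, 8, 256)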


class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit module definition.

    Args:
        size: Initial size to determine the number of channels.
        kernel_size: Size of the convolving kernel.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
        causal: Whether to use causal convolution (set to True if streaming).

    """

def __init__(
self,
size: int,
kernel_size: int,
norm_class: torch.nn.Module = torch.nn.LayerNorm,
norm_args: Dict = {},
dropout_rate: float = 0.0,
causal: bool = False,
) -> None:
"""Construct a ConvolutionalSpatialGatingUnit object."""
super().__init__()
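
        # The input is split in half along the feature axis (see forward),
        # so the convolution operates on size // 2 channels.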
channels = size // 2
self.kernel_size = kernel_size
if causal:
self.lorder = kernel_size - 1
padding = 0
else:
self.lorder = 0
padding = (kernel_size - 1) // 2
self.conv = torch.nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
)
self.norm = norm_class(channels, **norm_args)
self.activation = torch.nn.Identity()
self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x: ConvolutionalSpatialGatingUnit input sequences. (B, T, D_hidden)
            mask: Source mask. (B, T_2)
            cache: ConvolutionalSpatialGatingUnit input cache.
                (1, D_hidden // 2, conv_kernel - 1)

        Returns:
            x: ConvolutionalSpatialGatingUnit output sequences. (B, T, D_hidden // 2)
            cache: ConvolutionalSpatialGatingUnit output cache.
                (1, D_hidden // 2, conv_kernel - 1)

        """
        x_r, x_g = x.chunk(2, dim=-1)

        x_g = self.norm(x_g).transpose(1, 2)

        if mask is not None:
            x_g.masked_fill_(mask.unsqueeze(1).expand_as(x_g), 0.0)

        if self.lorder > 0:
            if cache is None:
                x_g = torch.nn.functional.pad(x_g, (self.lorder, 0), "constant", 0.0)
            else:
                x_g = torch.cat([cache, x_g], dim=2)
                cache = x_g[..., -self.lorder :]

        x_g = self.conv(x_g).transpose(1, 2)

        x = self.dropout(x_r * self.activation(x_g))

        return x, cache
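
# A minimal usage sketch for ConvolutionalSpatialGatingUnit (illustrative
# only; size=512 and the shapes below are assumptions). The input is split
# into two halves of size // 2, so the output feature dimension is halved.
def _example_csgu() -> None:
    """Hypothetical demo helper; not part of the original ESPnet API."""
    csgu = ConvolutionalSpatialGatingUnit(size=512, kernel_size=31)
    x = torch.randn(2, 16, 512)  # (B, T, size)

    out, _ = csgu(x)
    assert out.shape == (2, 16, 256)  # (B, T, size // 2)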


class DepthwiseConvolution(torch.nn.Module):
    """Depth-wise Convolution module definition.

    Args:
        size: Initial size to determine the number of channels.
        kernel_size: Size of the convolving kernel.
        causal: Whether to use causal convolution (set to True if streaming).

    """

def __init__(
self,
size: int,
kernel_size: int,
causal: bool = False,
) -> None:
"""Construct a DepthwiseConvolution object."""
super().__init__()
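
        # The convolution operates on 2 * size channels: the module expects
        # inputs whose feature dimension is twice the given size.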
channels = size + size
self.kernel_size = kernel_size
if causal:
self.lorder = kernel_size - 1
padding = 0
else:
self.lorder = 0
padding = (kernel_size - 1) // 2
self.conv = torch.nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x: DepthwiseConvolution input sequences. (B, T, D_hidden)
            mask: Source mask. (B, T_2)
            cache: DepthwiseConvolution input cache. (1, D_hidden, conv_kernel - 1)

        Returns:
            x: DepthwiseConvolution output sequences. (B, T, D_hidden)
            cache: DepthwiseConvolution output cache. (1, D_hidden, conv_kernel - 1)

        """
        x = x.transpose(1, 2)

        if mask is not None:
            x.masked_fill_(mask.unsqueeze(1).expand_as(x), 0.0)

        if self.lorder > 0:
            if cache is None:
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)
                cache = x[..., -self.lorder :]

        x = self.conv(x).transpose(1, 2)

        return x, cache
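
# A minimal usage sketch for DepthwiseConvolution (illustrative only;
# size=256 is an assumption). Since channels = size + size, the input
# feature dimension must be 2 * size; the time dimension is preserved.
def _example_depthwise_convolution() -> None:
    """Hypothetical demo helper; not part of the original ESPnet API."""
    dw_conv = DepthwiseConvolution(size=256, kernel_size=31)
    x = torch.randn(2, 16, 512)  # (B, T, 2 * size)

    out, _ = dw_conv(x)
    assert out.shape == (2, 16, 512)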