# Source code for espnet2.asr_transducer.encoder.blocks.ebranchformer

"""E-Branchformer block for Transducer encoder."""

from typing import Dict, Optional, Tuple

import torch


[docs]class EBranchformer(torch.nn.Module): """E-Branchformer module definition. Reference: https://arxiv.org/pdf/2210.00077.pdf Args: block_size: Input/output size. linear_size: Linear layers' hidden size. self_att: Self-attention module instance. feed_forward: Feed-forward module instance. feed_forward_macaron: Feed-forward module instance for macaron network. conv_mod: ConvolutionalSpatialGatingUnit module instance. depthwise_conv_mod: DepthwiseConvolution module instance. norm_class: Normalization class. norm_args: Normalization module arguments. dropout_rate: Dropout rate. """ def __init__( self, block_size: int, linear_size: int, self_att: torch.nn.Module, feed_forward: torch.nn.Module, feed_forward_macaron: torch.nn.Module, conv_mod: torch.nn.Module, depthwise_conv_mod: torch.nn.Module, norm_class: torch.nn.Module = torch.nn.LayerNorm, norm_args: Dict = {}, dropout_rate: float = 0.0, ) -> None: """Construct a E-Branchformer object.""" super().__init__() self.self_att = self_att self.feed_forward = feed_forward self.feed_forward_macaron = feed_forward_macaron self.feed_forward_scale = 0.5 self.conv_mod = conv_mod self.depthwise_conv_mod = depthwise_conv_mod self.channel_proj1 = torch.nn.Sequential( torch.nn.Linear(block_size, linear_size), torch.nn.GELU() ) self.channel_proj2 = torch.nn.Linear(linear_size // 2, block_size) self.merge_proj = torch.nn.Linear((block_size + block_size), block_size) self.norm_self_att = norm_class(block_size, **norm_args) self.norm_feed_forward = norm_class(block_size, **norm_args) self.norm_feed_forward_macaron = norm_class(block_size, **norm_args) self.norm_mlp = norm_class(block_size, **norm_args) self.norm_final = norm_class(block_size, **norm_args) self.dropout = torch.nn.Dropout(dropout_rate) self.block_size = block_size self.linear_size = linear_size self.cache = None
[docs] def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: """Initialize/Reset self-attention and convolution modules cache for streaming. Args: left_context: Number of previous frames the attention module can see in current chunk. device: Device to use for cache tensor. """ self.cache = [ torch.zeros( (1, left_context, self.block_size), device=device, ), torch.zeros( ( 1, self.linear_size // 2, self.conv_mod.kernel_size - 1, ), device=device, ), torch.zeros( ( 1, self.block_size + self.block_size, self.depthwise_conv_mod.kernel_size - 1, ), device=device, ), ]
[docs] def forward( self, x: torch.Tensor, pos_enc: torch.Tensor, mask: torch.Tensor, chunk_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Encode input sequences. Args: x: E-Branchformer input sequences. (B, T, D_block) pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) mask: Source mask. (B, T) chunk_mask: Chunk mask. (T_2, T_2) Returns: x: E-Branchformer output sequences. (B, T, D_block) mask: Source mask. (B, T) pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) """ residual = x x = self.norm_feed_forward_macaron(x) x = residual + self.feed_forward_scale * self.dropout( self.feed_forward_macaron(x) ) x1 = x x2 = x x1 = self.norm_self_att(x1) x1 = self.dropout( self.self_att(x1, x1, x1, pos_enc, mask=mask, chunk_mask=chunk_mask) ) x2 = self.norm_mlp(x2) x2 = self.channel_proj1(x2) x2, _ = self.conv_mod(x2, mask=mask) x2 = self.dropout(self.channel_proj2(x2)) x_concat = torch.cat([x1, x2], dim=-1) x_depth, _ = self.depthwise_conv_mod(x_concat, mask=mask) x = x + self.merge_proj(x_concat + x_depth) residual = x x = self.norm_feed_forward(x) x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x)) x = self.norm_final(x) return x, mask, pos_enc
[docs] def chunk_forward( self, x: torch.Tensor, pos_enc: torch.Tensor, mask: torch.Tensor, left_context: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Encode chunk of input sequence. Args: x: E-Branchformer input sequences. (B, T, D_block) pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) mask: Source mask. (B, T_2) left_context: Number of previous frames the attention module can see in current chunk. Returns: x: E-Branchformer output sequences. (B, T, D_block) pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) """ residual = x x = self.norm_feed_forward_macaron(x) x = residual + self.feed_forward_scale * self.feed_forward_macaron(x) x1 = x x2 = x x1 = self.norm_self_att(x1) if left_context > 0: key = torch.cat([self.cache[0], x1], dim=1) else: key = x1 att_cache = key[:, -left_context:, :] x1 = self.self_att(x1, key, key, pos_enc, mask=mask, left_context=left_context) x2 = self.norm_mlp(x2) x2 = self.channel_proj1(x2) x2, conv_cache = self.conv_mod(x2, cache=self.cache[1]) x2 = self.channel_proj2(x2) x_concat = torch.cat([x1, x2], dim=-1) x_depth, merge_cache = self.depthwise_conv_mod(x_concat, cache=self.cache[2]) x = x + self.merge_proj(x_concat + x_depth) residual = x x = self.norm_feed_forward(x) x = residual + self.feed_forward_scale * self.feed_forward(x) x = self.norm_final(x) self.cache = [att_cache, conv_cache, merge_cache] return x, pos_enc