Source code for espnet2.enh.separator.conformer_separator

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.complex_utils import is_complex
from espnet2.enh.separator.abs_separator import AbsSeparator
from espnet.nets.pytorch_backend.conformer.encoder import Encoder as ConformerEncoder
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask

# True when the installed PyTorch version is 1.9.0 or newer.
# NOTE(review): not referenced in the visible part of this module; confirm
# whether anything still depends on it before removing.
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class ConformerSeparator(AbsSeparator):
    """Mask-based separator built on a Conformer encoder.

    The input feature is passed through a Conformer encoder; one linear
    projection followed by a nonlinearity then produces a mask per output
    (one per speaker, plus one for noise when ``predict_noise`` is set).
    Each mask is multiplied with the original input.
    """

    def __init__(
        self,
        input_dim: int,
        num_spk: int = 2,
        predict_noise: bool = False,
        adim: int = 384,
        aheads: int = 4,
        layers: int = 6,
        linear_units: int = 1536,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        normalize_before: bool = False,
        concat_after: bool = False,
        dropout_rate: float = 0.1,
        input_layer: str = "linear",
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.1,
        nonlinear: str = "relu",
        conformer_pos_enc_layer_type: str = "rel_pos",
        conformer_self_attn_layer_type: str = "rel_selfattn",
        conformer_activation_type: str = "swish",
        use_macaron_style_in_conformer: bool = True,
        use_cnn_in_conformer: bool = True,
        conformer_enc_kernel_size: int = 7,
        padding_idx: int = -1,
    ):
        """Conformer separator.

        Args:
            input_dim: input feature dimension
            num_spk: number of speakers
            predict_noise: whether to output the estimated noise signal
            adim (int): Dimension of attention.
            aheads (int): The number of heads of multi head attention.
            layers (int): The number of transformer blocks.
            linear_units (int): The number of units of position-wise feed forward.
            positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
            positionwise_conv_kernel_size (int): Kernel size of
                positionwise conv1d layer.
            normalize_before (bool): Whether to use layer_norm before the first block.
            concat_after (bool): Whether to concat attention layer's input and output.
                if True, additional linear will be applied.
                i.e. x -> x + linear(concat(x, att(x)))
                if False, no additional linear will be applied. i.e. x -> x + att(x)
            dropout_rate (float): Dropout rate.
            input_layer (Union[str, torch.nn.Module]): Input layer type.
            positional_dropout_rate (float): Dropout rate after adding
                positional encoding.
            attention_dropout_rate (float): Dropout rate in attention.
            nonlinear: the nonlinear function for mask estimation,
                select from 'relu', 'tanh', 'sigmoid'
            conformer_pos_enc_layer_type (str): Encoder positional encoding layer type.
            conformer_self_attn_layer_type (str): Encoder attention layer type.
            conformer_activation_type (str): Encoder activation function type.
            use_macaron_style_in_conformer (bool): Whether to use macaron style for
                positionwise layer.
            use_cnn_in_conformer (bool): Whether to use convolution module.
            conformer_enc_kernel_size (int): Kernel size of convolution module.
            padding_idx (int): Padding idx for input_layer=embed.

        Raises:
            ValueError: if ``nonlinear`` is not one of 'sigmoid', 'relu', 'tanh'.
        """
        super().__init__()

        # Reject an unsupported activation before the (comparatively costly)
        # encoder is constructed.
        activations = {
            "sigmoid": torch.nn.Sigmoid,
            "relu": torch.nn.ReLU,
            "tanh": torch.nn.Tanh,
        }
        if nonlinear not in activations:
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self._num_spk = num_spk
        self.predict_noise = predict_noise

        self.conformer = ConformerEncoder(
            idim=input_dim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=linear_units,
            num_blocks=layers,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            input_layer=input_layer,
            normalize_before=normalize_before,
            concat_after=concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_enc_kernel_size,
            padding_idx=padding_idx,
        )

        # One projection head per speaker, plus one extra head for the noise
        # estimate when requested.
        num_outputs = self.num_spk + 1 if self.predict_noise else self.num_spk
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(adim, input_dim) for _ in range(num_outputs)]
        )

        self.nonlinear = activations[nonlinear]()

    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]
            additional (Dict or None): other data included in model
                NOTE: not used in this model

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
            ilens (torch.Tensor): (B,)
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
            ]
            When ``predict_noise`` is enabled, ``others`` additionally holds
            'noise1': the input multiplied by the noise mask.
        """
        # For a complex spectrum the encoder is fed its magnitude.
        feature = abs(input) if is_complex(input) else input

        # prepare pad_mask for transformer
        pad_mask = make_non_pad_mask(ilens).unsqueeze(1).to(feature.device)

        # The encoder's second return value is deliberately discarded so that
        # the caller's `ilens` is what gets returned, as documented above.
        # NOTE(review): per the espnet Encoder API that second value is the
        # mask tensor rather than lengths; confirm against the encoder.
        x, _ = self.conformer(feature, pad_mask)

        masks = [self.nonlinear(linear(x)) for linear in self.linear]

        mask_noise = None
        if self.predict_noise:
            # The last projection head is reserved for the noise estimate.
            *masks, mask_noise = masks

        masked = [input * m for m in masks]

        others = OrderedDict(
            ("mask_spk{}".format(idx), m) for idx, m in enumerate(masks, start=1)
        )
        if mask_noise is not None:
            others["noise1"] = input * mask_noise

        return masked, ilens, others

    @property
    def num_spk(self):
        """Number of speakers this separator was configured for."""
        return self._num_spk