Source code for espnet2.enh.separator.dc_crn_separator

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.complex_utils import is_complex, new_complex_like
from espnet2.enh.layers.dc_crn import DC_CRN
from espnet2.enh.separator.abs_separator import AbsSeparator

EPS = torch.finfo(torch.get_default_dtype()).eps
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class DC_CRNSeparator(AbsSeparator):
    def __init__(
        self,
        input_dim: int,
        num_spk: int = 2,
        predict_noise: bool = False,
        input_channels: List = [2, 16, 32, 64, 128, 256],
        enc_hid_channels: int = 8,
        enc_kernel_size: Tuple = (1, 3),
        enc_padding: Tuple = (0, 1),
        enc_last_kernel_size: Tuple = (1, 4),
        enc_last_stride: Tuple = (1, 2),
        enc_last_padding: Tuple = (0, 1),
        enc_layers: int = 5,
        skip_last_kernel_size: Tuple = (1, 3),
        skip_last_stride: Tuple = (1, 1),
        skip_last_padding: Tuple = (0, 1),
        glstm_groups: int = 2,
        glstm_layers: int = 2,
        glstm_bidirectional: bool = False,
        glstm_rearrange: bool = False,
        mode: str = "masking",
        ref_channel: int = 0,
    ):
        """Densely-Connected Convolutional Recurrent Network (DC-CRN) Separator.

        Reference:
            Deep Learning Based Real-Time Speech Enhancement for
            Dual-Microphone Mobile Phones; Tan et al., 2020
            https://web.cse.ohio-state.edu/~wang.77/papers/TZW.taslp21.pdf

        Args:
            input_dim (int): input feature dimension
            num_spk (int): number of speakers
            predict_noise (bool): whether to output the estimated noise signal
            input_channels (list): number of input channels for the stacked
                DenselyConnectedBlock layers;
                its length should be the number of DenselyConnectedBlock layers
            enc_hid_channels (int): common number of intermediate channels for
                all DenselyConnectedBlock layers of the encoder
            enc_kernel_size (tuple): common kernel size for all
                DenselyConnectedBlock layers of the encoder
            enc_padding (tuple): common padding for all DenselyConnectedBlock
                layers of the encoder
            enc_last_kernel_size (tuple): common kernel size for the last Conv
                layer in all DenselyConnectedBlock layers of the encoder
            enc_last_stride (tuple): common stride for the last Conv layer in
                all DenselyConnectedBlock layers of the encoder
            enc_last_padding (tuple): common padding for the last Conv layer in
                all DenselyConnectedBlock layers of the encoder
            enc_layers (int): common total number of Conv layers for all
                DenselyConnectedBlock layers of the encoder
            skip_last_kernel_size (tuple): common kernel size for the last Conv
                layer in all DenselyConnectedBlock layers of the skip pathways
            skip_last_stride (tuple): common stride for the last Conv layer in
                all DenselyConnectedBlock layers of the skip pathways
            skip_last_padding (tuple): common padding for the last Conv layer
                in all DenselyConnectedBlock layers of the skip pathways
            glstm_groups (int): number of groups in each Grouped LSTM layer
            glstm_layers (int): number of Grouped LSTM layers
            glstm_bidirectional (bool): whether to use BLSTM or unidirectional
                LSTM in Grouped LSTM layers
            glstm_rearrange (bool): whether to apply the rearrange operation
                after each grouped LSTM layer
            mode (str): one of ("mapping", "masking")
                "mapping": complex spectral mapping
                "masking": complex masking
            ref_channel (int): index of the reference microphone
        """
        super().__init__()
        self._num_spk = num_spk
        self.predict_noise = predict_noise
        self.mode = mode
        if mode not in ("mapping", "masking"):
            raise ValueError("mode=%s is not supported" % mode)
        self.ref_channel = ref_channel

        # One (real, imaginary) output pair per speaker, plus one more pair
        # for the noise signal if it is also predicted
        num_outputs = self.num_spk + 1 if self.predict_noise else self.num_spk
        self.dc_crn = DC_CRN(
            input_dim=input_dim,
            input_channels=input_channels,
            enc_hid_channels=enc_hid_channels,
            enc_kernel_size=enc_kernel_size,
            enc_padding=enc_padding,
            enc_last_kernel_size=enc_last_kernel_size,
            enc_last_stride=enc_last_stride,
            enc_last_padding=enc_last_padding,
            enc_layers=enc_layers,
            skip_last_kernel_size=skip_last_kernel_size,
            skip_last_stride=skip_last_stride,
            skip_last_padding=skip_last_padding,
            glstm_groups=glstm_groups,
            glstm_layers=glstm_layers,
            glstm_bidirectional=glstm_bidirectional,
            glstm_rearrange=glstm_rearrange,
            output_channels=num_outputs * 2,
        )
    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """DC-CRN Separator Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature
                [Batch, T, F] or [Batch, T, C, F]
            ilens (torch.Tensor): input lengths [Batch,]
            additional (Dict or None): not used in this separator

        Returns:
            masked (List[Union[torch.Tensor, ComplexTensor]]): [(Batch, T, F), ...]
            ilens (torch.Tensor): (B,)
            others (OrderedDict): predicted data, e.g. masks:
                OrderedDict[
                    'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                    'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                    ...
                    'mask_spkn': torch.Tensor(Batch, Frames, Freq),
                ]
        """
        assert is_complex(input)
        is_multichannel = input.ndim == 4
        # Stack real and imaginary parts along the channel dimension:
        # (B, 2 * C, T, F) for multichannel input, (B, 2, T, F) otherwise
        if is_multichannel:
            feature = torch.cat([input.real, input.imag], dim=2).permute(0, 2, 1, 3)
        else:
            feature = torch.stack([input.real, input.imag], dim=1)

        masks = self.dc_crn(feature)
        # (B, 2, num_outputs, T, F) -> num_outputs complex tensors of (B, T, F)
        masks = [new_complex_like(input, m.unbind(dim=1)) for m in masks.unbind(dim=2)]
        if self.predict_noise:
            *masks, mask_noise = masks

        if self.mode == "masking":
            if is_multichannel:
                masked = [input * m.unsqueeze(2) for m in masks]
            else:
                masked = [input * m for m in masks]
        else:
            # Complex spectral mapping: the network outputs are already the
            # separated spectra; derive the equivalent masks for reporting
            masked = masks
            if is_multichannel:
                masks = [m.unsqueeze(2) / (input + EPS) for m in masked]
            else:
                masks = [m / (input + EPS) for m in masked]

        others = OrderedDict(
            zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
        )
        if self.predict_noise:
            mask_noise = mask_noise.unsqueeze(2) if is_multichannel else mask_noise
            if self.mode == "masking":
                others["noise1"] = input * mask_noise
            else:
                others["noise1"] = mask_noise
        return masked, ilens, others
    @property
    def num_spk(self):
        return self._num_spk
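

# --- Usage sketch (not part of the original module) --------------------------
# A minimal example of constructing the separator and running a forward pass
# on random single-channel complex STFT features. All shapes and parameter
# values here are illustrative assumptions: input_dim=160 is chosen so that
# the five stride-(1, 2) encoder layers halve the frequency axis evenly; real
# recipes derive it from the STFT configuration. For multichannel input of
# shape (B, T, C, F), input_channels[0] would have to equal 2 * C instead of
# the default 2.
if __name__ == "__main__":
    separator = DC_CRNSeparator(input_dim=160, num_spk=2, mode="masking")
    spec = torch.complex(torch.randn(4, 100, 160), torch.randn(4, 100, 160))
    ilens = torch.full((4,), 100, dtype=torch.long)
    masked, ilens, others = separator(spec, ilens)
    print(len(masked), masked[0].shape)  # 2 estimates, each (4, 100, 160)
    print(list(others.keys()))  # ['mask_spk1', 'mask_spk2']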