Source code for espnet2.enh.separator.neural_beamformer

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.dnn_beamformer import DNN_Beamformer
from espnet2.enh.layers.dnn_wpe import DNN_WPE
from espnet2.enh.separator.abs_separator import AbsSeparator


class NeuralBeamformer(AbsSeparator):
    """Neural beamformer separator.

    Combines an optional DNN-WPE dereverberation front-end with an optional
    DNN-supported beamformer (e.g., MVDR). See the usage sketch after this
    class for construction and forward examples.
    """

    def __init__(
        self,
        input_dim: int,
        num_spk: int = 1,
        loss_type: str = "mask_mse",
        # Dereverberation options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        wnonlinear: str = "crelu",
        multi_source_wpe: bool = True,
        wnormalization: bool = False,
        # Beamformer options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        bnonlinear: str = "sigmoid",
        beamformer_type: str = "mvdr_souden",
        rtf_iterations: int = 2,
        bdropout_rate: float = 0.0,
        shared_power: bool = True,
        use_torchaudio_api: bool = False,
        # For numerical stability
        diagonal_loading: bool = True,
        diag_eps_wpe: float = 1e-7,
        diag_eps_bf: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres_wpe: float = 1e-6,
        flooring_thres_bf: float = 1e-6,
        use_torch_solver: bool = True,
    ):
        super().__init__()

        self._num_spk = num_spk
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "spectrum", "spectrum_log", "magnitude"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            if use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                iterations = 1
            else:
                # Performing as conventional WPE, without a DNN estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=input_dim,
                wlayers=wlayers,
                wunits=wunits,
                wprojs=wprojs,
                dropout_rate=wdropout_rate,
                taps=taps,
                delay=delay,
                use_dnn_mask=use_dnn_mask_for_wpe,
                nmask=1 if multi_source_wpe else num_spk,
                nonlinear=wnonlinear,
                iterations=iterations,
                normalization=wnormalization,
                diagonal_loading=diagonal_loading,
                diag_eps=diag_eps_wpe,
                mask_flooring=mask_flooring,
                flooring_thres=flooring_thres_wpe,
                use_torch_solver=use_torch_solver,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                bidim=input_dim,
                btype=bnet_type,
                blayers=blayers,
                bunits=bunits,
                bprojs=bprojs,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                nonlinear=bnonlinear,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                rtf_iterations=rtf_iterations,
                btaps=taps,
                bdelay=delay,
                diagonal_loading=diagonal_loading,
                diag_eps=diag_eps_bf,
                mask_flooring=mask_flooring,
                flooring_thres=flooring_thres_bf,
                use_torch_solver=use_torch_solver,
                use_torchaudio_api=use_torchaudio_api,
            )
        else:
            self.beamformer = None

        # share speech powers between WPE and beamforming (wMPDR/WPD)
        self.shared_power = shared_power and use_wpe

    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.complex64/ComplexTensor):
                mixed speech [Batch, Frames, Channel, Freq]
            ilens (torch.Tensor): input lengths [Batch]
            additional (Dict or None): other data included in model
                NOTE: not used in this model

        Returns:
            enhanced speech (single-channel):
                List[torch.complex64/ComplexTensor]
            output lengths
            other predicted data: OrderedDict[
                'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
                'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input.dim() in (3, 4), input.dim()
        enhanced = input
        others = OrderedDict()

        if (
            self.training
            and self.loss_type is not None
            and self.loss_type.startswith("mask")
        ):
            # Only estimating masks during training for saving memory
            if self.use_wpe:
                if input.dim() == 3:
                    mask_w, ilens = self.wpe.predict_mask(input.unsqueeze(-2), ilens)
                    if isinstance(mask_w, list):
                        mask_w = [m.squeeze(-2) for m in mask_w]
                    else:
                        mask_w = mask_w.squeeze(-2)
                elif input.dim() == 4:
                    mask_w, ilens = self.wpe.predict_mask(input, ilens)

                if mask_w is not None:
                    if isinstance(mask_w, list):
                        # single-source WPE (one mask per speaker)
                        for spk in range(self.num_spk):
                            others["mask_dereverb{}".format(spk + 1)] = mask_w[spk]
                    else:
                        # multi-source WPE (a single shared mask)
                        others["mask_dereverb1"] = mask_w

            if self.use_beamformer and input.dim() == 4:
                others_b, ilens = self.beamformer.predict_mask(input, ilens)
                for spk in range(self.num_spk):
                    others["mask_spk{}".format(spk + 1)] = others_b[spk]
                if len(others_b) > self.num_spk:
                    others["mask_noise1"] = others_b[self.num_spk]

            return None, ilens, others

        else:
            powers = None
            # Performing both mask estimation and enhancement
            if input.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(
                        input.unsqueeze(-2), ilens
                    )
                    if isinstance(enhanced, list):
                        # single-source WPE
                        enhanced = [enh.squeeze(-2) for enh in enhanced]
                        if mask_w is not None:
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk].squeeze(-2)
                    else:
                        # multi-source WPE
                        enhanced = enhanced.squeeze(-2)
                        if mask_w is not None:
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                    if mask_w is not None:
                        if isinstance(enhanced, list):
                            # single-source WPE
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk]
                        else:
                            # multi-source WPE
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)

                # 2. Beamformer
                if self.use_beamformer:
                    if (
                        not (
                            self.beamformer.beamformer_type.startswith("wmpdr")
                            or self.beamformer.beamformer_type.startswith("wpd")
                        )
                        or not self.shared_power
                        or (self.wpe.nmask == 1 and self.num_spk > 1)
                    ):
                        # Discard the estimated powers: they are only reusable
                        # for wMPDR/WPD beamforming with shared_power enabled
                        powers = None

                    # enhanced: (B, T, C, F) -> (B, T, F)
                    if isinstance(enhanced, list):
                        # outputs of single-source WPE
                        raise NotImplementedError(
                            "Single-source WPE is not supported with beamformer "
                            "in multi-speaker cases."
                        )
                    else:
                        # output of multi-source WPE
                        enhanced, ilens, others_b = self.beamformer(
                            enhanced, ilens, powers=powers
                        )
                    for spk in range(self.num_spk):
                        others["mask_spk{}".format(spk + 1)] = others_b[spk]
                    if len(others_b) > self.num_spk:
                        others["mask_noise1"] = others_b[self.num_spk]

        if not isinstance(enhanced, list):
            enhanced = [enhanced]

        return enhanced, ilens, others

    @property
    def num_spk(self):
        return self._num_spk
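
A minimal construction sketch (not part of the module above). The values below are illustrative assumptions rather than defaults from any recipe: input_dim must match the number of STFT frequency bins (n_fft // 2 + 1, e.g. 257 for a 512-point STFT), and a two-speaker mixture with both WPE and beamforming enabled is assumed.

separator = NeuralBeamformer(
    input_dim=257,                  # STFT bins: n_fft // 2 + 1 (assumed 512-point STFT)
    num_spk=2,
    loss_type="mask_mse",
    use_wpe=True,                   # enable the DNN-WPE dereverberation front-end
    use_beamformer=True,
    beamformer_type="mvdr_souden",
    ref_channel=0,                  # fix the reference microphone (-1 estimates one)
)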
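
A forward-pass sketch under the same assumptions, feeding a random multi-channel complex spectrum of shape (Batch, Frames, Channel, Freq). It assumes native complex-tensor support (torch.complex64); otherwise the input could be wrapped in a torch_complex ComplexTensor, as the signature allows. Note that in training mode with a mask-based loss, forward returns (None, ilens, others) with masks only.

B, T, C, F = 2, 100, 4, 257
mix = torch.randn(B, T, C, F, dtype=torch.complex64)  # hypothetical mixture STFT
ilens = torch.full((B,), T, dtype=torch.long)          # all utterances span T frames

separator.eval()  # run the full enhancement path, not the mask-only training path
with torch.no_grad():
    enhanced, olens, others = separator(mix, ilens)

# enhanced: list of per-speaker (B, T, F) single-channel spectra;
# others: intermediate products such as "dereverb1" and the estimated masks.
print(len(enhanced), enhanced[0].shape, sorted(others.keys()))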