Source code for espnet2.enh.espnet_model

"""Enhancement model module."""
import contextlib
from typing import Dict, List, Optional, OrderedDict, Tuple

import numpy as np
import torch
from packaging.version import parse as V
from typeguard import check_argument_types

from espnet2.diar.layers.abs_mask import AbsMask
from espnet2.enh.decoder.abs_decoder import AbsDecoder
from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainLoss
from espnet2.enh.loss.criterions.time_domain import TimeDomainLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
from espnet2.enh.separator.abs_separator import AbsSeparator
from espnet2.enh.separator.dan_separator import DANSeparator
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

EPS = torch.finfo(torch.get_default_dtype()).eps


class ESPnetEnhancementModel(AbsESPnetModel):
    """Speech enhancement or separation Frontend model"""

    def __init__(
        self,
        encoder: AbsEncoder,
        separator: AbsSeparator,
        decoder: AbsDecoder,
        mask_module: Optional[AbsMask],
        loss_wrappers: List[AbsLossWrapper],
        stft_consistency: bool = False,
        loss_type: str = "mask_mse",
        mask_type: Optional[str] = None,
        extract_feats_in_collect_stats: bool = False,
    ):
        assert check_argument_types()

        super().__init__()

        self.encoder = encoder
        self.separator = separator
        self.decoder = decoder
        self.mask_module = mask_module
        self.num_spk = separator.num_spk
        self.num_noise_type = getattr(self.separator, "num_noise_type", 1)

        self.loss_wrappers = loss_wrappers
        names = [w.criterion.name for w in self.loss_wrappers]
        if len(set(names)) != len(names):
            raise ValueError("Duplicated loss names are not allowed: {}".format(names))

        # get mask type for TF-domain models
        # (only used when loss_type="mask_*") (deprecated, keep for compatibility)
        self.mask_type = mask_type.upper() if mask_type else None

        # get loss type for model training (deprecated, keep for compatibility)
        self.loss_type = loss_type

        # whether to compute the TF-domain loss while enforcing STFT consistency
        # (deprecated, keep for compatibility)
        # NOTE: STFT consistency is now always used for
        # frequency-domain spectrum losses
        self.stft_consistency = stft_consistency

        # for multi-channel signal
        self.ref_channel = getattr(self.separator, "ref_channel", None)
        if self.ref_channel is None:
            self.ref_channel = 0

        # Used in espnet2/tasks/abs_task.py for determining whether or not to do
        # collect_feats during collect stats (stage 5).
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
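
    # Construction sketch (illustrative only): the concrete component classes and
    # constructor arguments below are assumptions, not requirements of this class;
    # any AbsEncoder / AbsSeparator / AbsDecoder plus AbsLossWrapper subclasses with
    # compatible feature dimensions can be combined in the same way.
    #
    #     from espnet2.enh.decoder.stft_decoder import STFTDecoder
    #     from espnet2.enh.encoder.stft_encoder import STFTEncoder
    #     from espnet2.enh.loss.criterions.time_domain import SISNRLoss
    #     from espnet2.enh.loss.wrappers.pit_solver import PITSolver
    #     from espnet2.enh.separator.rnn_separator import RNNSeparator
    #
    #     model = ESPnetEnhancementModel(
    #         encoder=STFTEncoder(n_fft=512, hop_length=128),
    #         separator=RNNSeparator(input_dim=257, num_spk=2),  # 257 = 512 // 2 + 1 bins
    #         decoder=STFTDecoder(n_fft=512, hop_length=128),
    #         mask_module=None,
    #         loss_wrappers=[PITSolver(criterion=SISNRLoss(), weight=1.0)],
    #     )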

    def forward(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_ref: (Batch, num_speaker, samples)
                        or (Batch, num_speaker, samples, channels)
            speech_mix_lengths: (Batch,), default None for the chunk iterator,
                        because the chunk iterator does not return speech_lengths.
                        See espnet2/iterators/chunk_iter_factory.py
            kwargs: "utt_id" is among the input.
        """
        # reference speech signal of each speaker
        assert "speech_ref1" in kwargs, "At least 1 reference signal input is required."
        speech_ref = [
            kwargs.get(
                f"speech_ref{spk + 1}",
                torch.zeros_like(kwargs["speech_ref1"]),
            )
            for spk in range(self.num_spk)
        ]
        # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
        speech_ref = torch.stack(speech_ref, dim=1)

        if "noise_ref1" in kwargs:
            # noise signal (optional, required when using beamforming-based
            # frontend models)
            noise_ref = [
                kwargs["noise_ref{}".format(n + 1)] for n in range(self.num_noise_type)
            ]
            # (Batch, num_noise_type, samples) or
            # (Batch, num_noise_type, samples, channels)
            noise_ref = torch.stack(noise_ref, dim=1)
        else:
            noise_ref = None

        # dereverberated (noisy) signal
        # (optional, only used for frontend models with WPE)
        if "dereverb_ref1" in kwargs:
            # dereverberated references (either a single shared one
            # or one per speaker)
            dereverb_speech_ref = [
                kwargs["dereverb_ref{}".format(n + 1)]
                for n in range(self.num_spk)
                if "dereverb_ref{}".format(n + 1) in kwargs
            ]
            assert len(dereverb_speech_ref) in (1, self.num_spk), len(
                dereverb_speech_ref
            )
            # (Batch, N, samples) or (Batch, N, samples, channels)
            dereverb_speech_ref = torch.stack(dereverb_speech_ref, dim=1)
        else:
            dereverb_speech_ref = None

        batch_size = speech_mix.shape[0]
        speech_lengths = (
            speech_mix_lengths
            if speech_mix_lengths is not None
            else torch.ones(batch_size).int().fill_(speech_mix.shape[1])
        )
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # Check that batch_size is unified
        assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], (
            speech_mix.shape,
            speech_ref.shape,
            speech_lengths.shape,
        )

        # for data-parallel
        speech_ref = speech_ref[..., : speech_lengths.max()].unbind(dim=1)
        if noise_ref is not None:
            noise_ref = noise_ref[..., : speech_lengths.max()].unbind(dim=1)
        if dereverb_speech_ref is not None:
            dereverb_speech_ref = dereverb_speech_ref[..., : speech_lengths.max()]
            dereverb_speech_ref = dereverb_speech_ref.unbind(dim=1)

        additional = {}
        # Additional data is required in Deep Attractor Network
        if isinstance(self.separator, DANSeparator):
            additional["feature_ref"] = [
                self.encoder(r, speech_lengths)[0] for r in speech_ref
            ]

        speech_mix = speech_mix[:, : speech_lengths.max()]

        # model forward
        speech_pre, feature_mix, feature_pre, others = self.forward_enhance(
            speech_mix, speech_lengths, additional
        )

        # loss computation
        loss, stats, weight, perm = self.forward_loss(
            speech_pre,
            speech_lengths,
            feature_mix,
            feature_pre,
            others,
            speech_ref,
            noise_ref,
            dereverb_speech_ref,
        )
        return loss, stats, weight
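
    # Forward-call sketch with dummy tensors, assuming a 2-speaker model such as
    # the one sketched above (batch size and sample count are arbitrary):
    #
    #     speech_mix = torch.randn(4, 16000)                   # (Batch, samples)
    #     lengths = torch.full((4,), 16000, dtype=torch.long)  # (Batch,)
    #     loss, stats, weight = model(
    #         speech_mix,
    #         speech_mix_lengths=lengths,
    #         speech_ref1=torch.randn(4, 16000),               # one reference per speaker
    #         speech_ref2=torch.randn(4, 16000),
    #     )
    #     loss.backward()  # stats holds the per-criterion values for logging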

    def forward_enhance(
        self,
        speech_mix: torch.Tensor,
        speech_lengths: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, OrderedDict]:
        feature_mix, flens = self.encoder(speech_mix, speech_lengths)
        if self.mask_module is None:
            feature_pre, flens, others = self.separator(feature_mix, flens, additional)
        else:
            # Obtain bottleneck_feats from the separator.
            # They are used as the input to the diarization module
            # in the "enh + diar" task.
            bottleneck_feats, bottleneck_feats_lengths = self.separator(
                feature_mix, flens
            )
            if additional.get("num_spk") is not None:
                feature_pre, flens, others = self.mask_module(
                    feature_mix, flens, bottleneck_feats, additional["num_spk"]
                )
                others["bottleneck_feats"] = bottleneck_feats
                others["bottleneck_feats_lengths"] = bottleneck_feats_lengths
            else:
                feature_pre = None
                others = {
                    "bottleneck_feats": bottleneck_feats,
                    "bottleneck_feats_lengths": bottleneck_feats_lengths,
                }
        if feature_pre is not None:
            # for models like SVoice that output multiple lists of separated signals
            pre_is_multi_list = isinstance(feature_pre[0], (list, tuple))
            if pre_is_multi_list:
                speech_pre = [
                    [self.decoder(p, speech_lengths)[0] for p in ps]
                    for ps in feature_pre
                ]
            else:
                speech_pre = [
                    self.decoder(ps, speech_lengths)[0] for ps in feature_pre
                ]
        else:
            # some models (e.g. neural beamformer trained with mask loss)
            # do not predict time-domain signal in the training stage
            speech_pre = None
        return speech_pre, feature_mix, feature_pre, others

    def forward_loss(
        self,
        speech_pre: torch.Tensor,
        speech_lengths: torch.Tensor,
        feature_mix: torch.Tensor,
        feature_pre: torch.Tensor,
        others: OrderedDict,
        speech_ref: torch.Tensor,
        noise_ref: Optional[torch.Tensor] = None,
        dereverb_speech_ref: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Optional[torch.Tensor]
    ]:
        # for calculating loss on estimated noise signals
        if getattr(self.separator, "predict_noise", False):
            assert "noise1" in others, others.keys()
        if noise_ref is not None and "noise1" in others:
            for n in range(self.num_noise_type):
                key = "noise{}".format(n + 1)
                others[key] = self.decoder(others[key], speech_lengths)[0]
        # for calculating loss on dereverberated signals
        if getattr(self.separator, "predict_dereverb", False):
            assert "dereverb1" in others, others.keys()
        if dereverb_speech_ref is not None and "dereverb1" in others:
            for spk in range(self.num_spk):
                key = "dereverb{}".format(spk + 1)
                if key in others:
                    others[key] = self.decoder(others[key], speech_lengths)[0]

        loss = 0.0
        stats = {}
        o = {}
        perm = None
        for loss_wrapper in self.loss_wrappers:
            criterion = loss_wrapper.criterion
            if getattr(criterion, "only_for_test", False) and self.training:
                continue

            if getattr(criterion, "is_noise_loss", False):
                if noise_ref is None:
                    raise ValueError(
                        "No noise reference for training!\n"
                        'Please specify "--use_noise_ref true" in run.sh'
                    )
                signal_ref = noise_ref
                signal_pre = [
                    others["noise{}".format(n + 1)]
                    for n in range(self.num_noise_type)
                ]
            elif getattr(criterion, "is_dereverb_loss", False):
                if dereverb_speech_ref is None:
                    raise ValueError(
                        "No dereverberated reference for training!\n"
                        'Please specify "--use_dereverb_ref true" in run.sh'
                    )
                signal_ref = dereverb_speech_ref
                signal_pre = [
                    others["dereverb{}".format(n + 1)]
                    for n in range(self.num_noise_type)
                    if "dereverb{}".format(n + 1) in others
                ]
                if len(signal_pre) == 0:
                    signal_pre = None
            else:
                signal_ref = speech_ref
                signal_pre = speech_pre

            zero_weight = loss_wrapper.weight == 0.0
            if isinstance(criterion, TimeDomainLoss):
                assert signal_pre is not None
                sref, spre = self._align_ref_pre_channels(
                    signal_ref, signal_pre, ch_dim=2, force_1ch=True
                )
                # for the time-domain criterions
                with torch.no_grad() if zero_weight else contextlib.ExitStack():
                    l, s, o = loss_wrapper(sref, spre, {**others, **o})
            elif isinstance(criterion, FrequencyDomainLoss):
                sref, spre = self._align_ref_pre_channels(
                    signal_ref, signal_pre, ch_dim=2, force_1ch=False
                )
                # for the time-frequency domain criterions
                if criterion.compute_on_mask:
                    # compute loss on masks
                    if getattr(criterion, "is_noise_loss", False):
                        tf_ref, tf_pre = self._get_noise_masks(
                            criterion,
                            feature_mix,
                            speech_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                        )
                    elif getattr(criterion, "is_dereverb_loss", False):
                        tf_ref, tf_pre = self._get_dereverb_masks(
                            criterion,
                            feature_mix,
                            noise_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                        )
                    else:
                        tf_ref, tf_pre = self._get_speech_masks(
                            criterion,
                            feature_mix,
                            noise_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                        )
                else:
                    # compute loss on spectra
                    tf_ref = [self.encoder(sr, speech_lengths)[0] for sr in sref]
                    # for models like SVoice that output multiple lists of
                    # separated signals
                    pre_is_multi_list = isinstance(spre[0], (list, tuple))
                    if pre_is_multi_list:
                        tf_pre = [
                            [self.encoder(sp, speech_lengths)[0] for sp in ps]
                            for ps in spre
                        ]
                    else:
                        tf_pre = [
                            self.encoder(sp, speech_lengths)[0] for sp in spre
                        ]

                with torch.no_grad() if zero_weight else contextlib.ExitStack():
                    l, s, o = loss_wrapper(tf_ref, tf_pre, {**others, **o})
            else:
                raise NotImplementedError("Unsupported loss type: %s" % str(criterion))

            loss += l * loss_wrapper.weight
            stats.update(s)

            if perm is None and "perm" in o:
                perm = o["perm"]

        if self.training and isinstance(loss, float):
            raise AttributeError(
                "At least one criterion must satisfy: only_for_test=False"
            )
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        batch_size = speech_ref[0].shape[0]
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight, perm

    def _align_ref_pre_channels(self, ref, pre, ch_dim=2, force_1ch=False):
        if ref is None or pre is None:
            return ref, pre
        # NOTE: input must be a list of time-domain signals
        index = ref[0].new_tensor(self.ref_channel, dtype=torch.long)

        # for models like SVoice that output multiple lists of separated signals
        pre_is_multi_list = isinstance(pre[0], (list, tuple))
        pre_dim = pre[0][0].dim() if pre_is_multi_list else pre[0].dim()

        if ref[0].dim() > pre_dim:
            # multi-channel reference and single-channel output
            ref = [r.index_select(ch_dim, index).squeeze(ch_dim) for r in ref]
        elif ref[0].dim() < pre_dim:
            # single-channel reference and multi-channel output
            if pre_is_multi_list:
                pre = [
                    p.index_select(ch_dim, index).squeeze(ch_dim)
                    for plist in pre
                    for p in plist
                ]
            else:
                pre = [p.index_select(ch_dim, index).squeeze(ch_dim) for p in pre]
        elif ref[0].dim() == pre_dim == 3 and force_1ch:
            # multi-channel reference and output
            ref = [r.index_select(ch_dim, index).squeeze(ch_dim) for r in ref]
            if pre_is_multi_list:
                pre = [
                    p.index_select(ch_dim, index).squeeze(ch_dim)
                    for plist in pre
                    for p in plist
                ]
            else:
                pre = [p.index_select(ch_dim, index).squeeze(ch_dim) for p in pre]
        return ref, pre

    def _get_noise_masks(
        self, criterion, feature_mix, speech_ref, noise_ref, noise_pre, ilens, others
    ):
        speech_spec = self.encoder(sum(speech_ref), ilens)[0]
        masks_ref = criterion.create_mask_label(
            feature_mix,
            [self.encoder(nr, ilens)[0] for nr in noise_ref],
            noise_spec=speech_spec,
        )
        if "mask_noise1" in others:
            masks_pre = [
                others["mask_noise{}".format(n + 1)]
                for n in range(self.num_noise_type)
            ]
        else:
            assert len(noise_pre) == len(noise_ref), (len(noise_pre), len(noise_ref))
            masks_pre = criterion.create_mask_label(
                feature_mix,
                [self.encoder(np, ilens)[0] for np in noise_pre],
                noise_spec=speech_spec,
            )
        return masks_ref, masks_pre

    def _get_dereverb_masks(
        self, criterion, feat_mix, noise_ref, dereverb_ref, dereverb_pre, ilens, others
    ):
        if noise_ref is not None:
            noise_spec = self.encoder(sum(noise_ref), ilens)[0]
        else:
            noise_spec = None
        masks_ref = criterion.create_mask_label(
            feat_mix,
            [self.encoder(dr, ilens)[0] for dr in dereverb_ref],
            noise_spec=noise_spec,
        )
        if "mask_dereverb1" in others:
            masks_pre = [
                others["mask_dereverb{}".format(spk + 1)]
                for spk in range(self.num_spk)
                if "mask_dereverb{}".format(spk + 1) in others
            ]
            assert len(masks_pre) == len(masks_ref), (len(masks_pre), len(masks_ref))
        else:
            assert len(dereverb_pre) == len(dereverb_ref), (
                len(dereverb_pre),
                len(dereverb_ref),
            )
            masks_pre = criterion.create_mask_label(
                feat_mix,
                [self.encoder(dp, ilens)[0] for dp in dereverb_pre],
                noise_spec=noise_spec,
            )
        return masks_ref, masks_pre

    def _get_speech_masks(
        self, criterion, feature_mix, noise_ref, speech_ref, speech_pre, ilens, others
    ):
        if noise_ref is not None:
            noise_spec = self.encoder(sum(noise_ref), ilens)[0]
        else:
            noise_spec = None
        masks_ref = criterion.create_mask_label(
            feature_mix,
            [self.encoder(sr, ilens)[0] for sr in speech_ref],
            noise_spec=noise_spec,
        )
        if "mask_spk1" in others:
            masks_pre = [
                others["mask_spk{}".format(spk + 1)] for spk in range(self.num_spk)
            ]
        else:
            masks_pre = criterion.create_mask_label(
                feature_mix,
                [self.encoder(sp, ilens)[0] for sp in speech_pre],
                noise_spec=noise_spec,
            )
        return masks_ref, masks_pre
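
    # Shape sketch for the channel alignment above: a multi-channel reference
    # (Batch, samples, channels) is reduced to the configured ref_channel before
    # being compared against a single-channel estimate (Batch, samples), e.g.:
    #
    #     r = torch.randn(4, 16000, 6)                 # (Batch, samples, channels)
    #     index = r.new_tensor(0, dtype=torch.long)    # ref_channel = 0
    #     r_1ch = r.index_select(2, index).squeeze(2)  # -> (4, 16000)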

    @staticmethod
    def sort_by_perm(nn_output, perm):
        """Sort the input list of tensors by the specified permutation.

        Args:
            nn_output: List[torch.Tensor(Batch, ...)], len(nn_output) == num_spk
            perm: (Batch, num_spk) or List[torch.Tensor(num_spk)]

        Returns:
            nn_output_new: List[torch.Tensor(Batch, ...)]
        """
        if len(nn_output) == 1:
            return nn_output
        # (Batch, num_spk, ...)
        nn_output = torch.stack(nn_output, dim=1)
        if not isinstance(perm, torch.Tensor):
            # perm is a list or tuple
            perm = torch.stack(perm, dim=0)
        assert nn_output.size(1) == perm.size(1), (nn_output.shape, perm.shape)
        diff_dim = nn_output.dim() - perm.dim()
        if diff_dim > 0:
            perm = perm.view(*perm.shape, *[1 for _ in range(diff_dim)]).expand_as(
                nn_output
            )
        return torch.gather(nn_output, 1, perm).unbind(dim=1)
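
    # Worked example for ``sort_by_perm`` with two estimated speakers and a
    # batch of two utterances (values are arbitrary, only shapes matter):
    #
    #     a, b = torch.zeros(2, 8), torch.ones(2, 8)
    #     perm = torch.tensor([[0, 1], [1, 0]])  # keep order for sample 0, swap for sample 1
    #     out = ESPnetEnhancementModel.sort_by_perm([a, b], perm)
    #     # out[0][0] == a[0], out[0][1] == b[1]
    #     # out[1][0] == b[0], out[1][1] == a[1]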

    def collect_feats(
        self, speech_mix: torch.Tensor, speech_mix_lengths: torch.Tensor, **kwargs
    ) -> Dict[str, torch.Tensor]:
        # for data-parallel
        speech_mix = speech_mix[:, : speech_mix_lengths.max()]
        feats, feats_lengths = speech_mix, speech_mix_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}