Source code for espnet2.enh.loss.wrappers.pit_solver

from collections import defaultdict
from itertools import permutations

import torch

from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper


class PITSolver(AbsLossWrapper):
    def __init__(
        self,
        criterion: AbsEnhLoss,
        weight=1.0,
        independent_perm=True,
        flexible_numspk=False,
    ):
        """Permutation Invariant Training Solver.

        Args:
            criterion (AbsEnhLoss): an instance of AbsEnhLoss
            weight (float): weight (between 0 and 1) of the current loss
                for multi-task learning.
            independent_perm (bool):
                If True, PIT will be performed in forward to find the best
                permutation;
                If False, the permutation from the last LossWrapper output
                will be inherited.
                NOTE (wangyou): You should be careful about the ordering of
                loss wrappers defined in the yaml config, if this argument
                is False.
            flexible_numspk (bool):
                If True, num_spk will be taken from inf to handle flexible
                numbers of speakers. This is because ref may include dummy
                data in this case.
        """
        super().__init__()
        self.criterion = criterion
        self.weight = weight
        self.independent_perm = independent_perm
        self.flexible_numspk = flexible_numspk
    def forward(self, ref, inf, others={}):
        """PITSolver forward.

        Args:
            ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk
            inf (List[torch.Tensor]): [(batch, ...), ...]

        Returns:
            loss (torch.Tensor): minimum loss with the best permutation
            stats (dict): for collecting training status
            others (dict): in this PIT solver, the permutation order will be
                returned
        """
        perm = others["perm"] if "perm" in others else None

        if not self.flexible_numspk:
            assert len(ref) == len(inf), (len(ref), len(inf))
            num_spk = len(ref)
        else:
            num_spk = len(inf)
        stats = defaultdict(list)

        def pre_hook(func, *args, **kwargs):
            ret = func(*args, **kwargs)
            for k, v in getattr(self.criterion, "stats", {}).items():
                stats[k].append(v)
            return ret

        def pair_loss(permutation):
            return sum(
                [
                    pre_hook(self.criterion, ref[s], inf[t])
                    for s, t in enumerate(permutation)
                ]
            ) / len(permutation)

        if self.independent_perm or perm is None:
            # compute the permutation independently
            device = ref[0].device
            all_permutations = list(permutations(range(num_spk)))
            losses = torch.stack([pair_loss(p) for p in all_permutations], dim=1)
            loss, perm_ = torch.min(losses, dim=1)
            perm = torch.index_select(
                torch.tensor(all_permutations, device=device, dtype=torch.long),
                0,
                perm_,
            )
            # remove stats from unused permutations
            for k, v in stats.items():
                # (B, num_spk * len(all_permutations), ...)
                new_v = torch.stack(v, dim=1)
                B, L, *rest = new_v.shape
                assert L == num_spk * len(all_permutations), (L, num_spk)
                new_v = new_v.view(B, L // num_spk, num_spk, *rest).mean(2)
                if new_v.dim() > 2:
                    shapes = [1 for _ in rest]
                    perm0 = perm_.view(perm_.shape[0], 1, *shapes).expand(
                        -1, -1, *rest
                    )
                else:
                    perm0 = perm_.unsqueeze(1)
                stats[k] = new_v.gather(1, perm0.to(device=new_v.device)).unbind(1)
        else:
            loss = torch.tensor(
                [
                    torch.tensor(
                        [
                            pre_hook(
                                self.criterion,
                                ref[s][batch].unsqueeze(0),
                                inf[t][batch].unsqueeze(0),
                            )
                            for s, t in enumerate(p)
                        ]
                    ).mean()
                    for batch, p in enumerate(perm)
                ]
            )
            loss = loss.mean()

            for k, v in stats.items():
                stats[k] = torch.stack(v, dim=1).mean()

        stats[self.criterion.name] = loss.detach()

        return loss.mean(), dict(stats), {"perm": perm}
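For reference, here is a minimal usage sketch. The ToyL1Loss criterion below is a
toy stand-in written for this example and is not part of ESPnet; it assumes only
that an AbsEnhLoss subclass needs a name property and a forward(ref, inf) that
returns a per-sample loss of shape (batch,), which is what PITSolver relies on
above.

import torch

from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
from espnet2.enh.loss.wrappers.pit_solver import PITSolver


class ToyL1Loss(AbsEnhLoss):
    """Toy per-utterance L1 criterion (illustration only, not part of ESPnet)."""

    @property
    def name(self):
        return "toy_l1"

    def forward(self, ref, inf):
        # (batch, samples) -> (batch,), as PITSolver expects
        return (ref - inf).abs().mean(dim=-1)


solver = PITSolver(criterion=ToyL1Loss(), weight=1.0)

# Two speakers, batch of 4, 16000-sample signals
ref = [torch.randn(4, 16000) for _ in range(2)]
inf = [torch.randn(4, 16000) for _ in range(2)]

loss, stats, others = solver(ref, inf)
print(loss)                  # scalar: mean loss under the best permutation
print(others["perm"].shape)  # (4, 2): per-sample speaker assignment

# A later wrapper with independent_perm=False can inherit this permutation
solver2 = PITSolver(criterion=ToyL1Loss(), independent_perm=False)
loss2, _, _ = solver2(ref, inf, {"perm": others["perm"]})

Note that the independent search enumerates all num_spk! permutations, so it is
only practical for small speaker counts; the independent_perm=False path shown
last avoids the search by reusing a permutation computed earlier in the wrapper
chain.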