# Copyright 2021 Jiatong Shi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from contextlib import contextmanager
from itertools import permutations
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from packaging.version import parse as V
from typeguard import check_argument_types
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.diar.attractor.abs_attractor import AbsAttractor
from espnet2.diar.decoder.abs_decoder import AbsDecoder
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet.nets.pytorch_backend.nets_utils import to_device
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
    # Fall back to a no-op autocast context manager for torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield


class ESPnetDiarizationModel(AbsESPnetModel):
    """Speaker Diarization model

    If "attractor" is None, SA-EEND is used;
    otherwise, EEND-EDA is used.
For the details about SA-EEND and EEND-EDA, refer to the following papers:
SA-EEND: https://arxiv.org/pdf/1909.06247.pdf
EEND-EDA: https://arxiv.org/pdf/2005.09921.pdf, https://arxiv.org/pdf/2106.10654.pdf
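
    The training objective is permutation-invariant (PIT) BCE over speaker
    activities; with EEND-EDA an attractor existence loss is added:
    loss = diar_weight * loss_pit + attractor_weight * loss_att.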
"""
def __init__(
self,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
label_aggregator: torch.nn.Module,
encoder: AbsEncoder,
decoder: AbsDecoder,
attractor: Optional[AbsAttractor],
diar_weight: float = 1.0,
attractor_weight: float = 1.0,
):
assert check_argument_types()
super().__init__()
self.encoder = encoder
self.normalize = normalize
self.frontend = frontend
self.specaug = specaug
self.label_aggregator = label_aggregator
self.diar_weight = diar_weight
self.attractor_weight = attractor_weight
self.attractor = attractor
self.decoder = decoder
if self.attractor is not None:
self.decoder = None
elif self.decoder is not None:
self.num_spk = decoder.num_spk
else:
raise NotImplementedError

    def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor = None,
spk_labels: torch.Tensor = None,
spk_labels_lengths: torch.Tensor = None,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, samples)
            speech_lengths: (Batch,); defaults to None for the chunk
                            iterator, because the chunk iterator does not
                            return speech_lengths (see
                            espnet2/iterators/chunk_iter_factory.py).
spk_labels: (Batch, )
kwargs: "utt_id" is among the input.
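
        Returns:
            loss: scalar loss tensor
            stats: dict of detached sub-losses and diarization metrics
                   (sad_mr, sad_fr, mi, fa, cf, acc, der) for logging
            weight: batch size, used as the weight when aggregating
                    loss/stats across devices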
"""
assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
batch_size = speech.shape[0]
# 1. Encoder
        # Use bottleneck_feats if they exist; only used for the "enh + diar" task.
bottleneck_feats = kwargs.get("bottleneck_feats", None)
bottleneck_feats_lengths = kwargs.get("bottleneck_feats_lengths", None)
encoder_out, encoder_out_lens = self.encode(
speech, speech_lengths, bottleneck_feats, bottleneck_feats_lengths
)
if self.attractor is None:
            # 2a. Decoder (basically a prediction layer after encoder_out)
pred = self.decoder(encoder_out, encoder_out_lens)
else:
# 2b. Encoder Decoder Attractors
# Shuffle the chronological order of encoder_out, then calculate attractor
encoder_out_shuffled = encoder_out.clone()
for i in range(len(encoder_out_lens)):
encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
i, torch.randperm(encoder_out_lens[i]), :
]
attractor, att_prob = self.attractor(
encoder_out_shuffled,
encoder_out_lens,
to_device(
self,
torch.zeros(
encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
),
),
)
# Remove the final attractor which does not correspond to a speaker
# Then multiply the attractors and encoder_out
pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
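            # Shape note (assuming encoder_out: (Batch, T, D)):
            #   attractor is (Batch, num_spk + 1, D), so pred above is
            #   (Batch, T, D) x (Batch, D, num_spk) -> (Batch, T, num_spk).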
# 3. Aggregate time-domain labels
spk_labels, spk_labels_lengths = self.label_aggregator(
spk_labels, spk_labels_lengths
)
# If encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of 'pred' might be slightly less than the
# length of 'spk_labels'. Here we force them to be equal.
length_diff_tolerance = 2
length_diff = spk_labels.shape[1] - pred.shape[1]
if length_diff > 0 and length_diff <= length_diff_tolerance:
spk_labels = spk_labels[:, 0 : pred.shape[1], :]
if self.attractor is None:
loss_pit, loss_att = None, None
loss, perm_idx, perm_list, label_perm = self.pit_loss(
pred, spk_labels, encoder_out_lens
)
else:
loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
pred, spk_labels, encoder_out_lens
)
loss_att = self.attractor_loss(att_prob, spk_labels)
loss = self.diar_weight * loss_pit + self.attractor_weight * loss_att
(
correct,
num_frames,
speech_scored,
speech_miss,
speech_falarm,
speaker_scored,
speaker_miss,
speaker_falarm,
speaker_error,
) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
if speech_scored > 0 and num_frames > 0:
sad_mr, sad_fr, mi, fa, cf, acc, der = (
speech_miss / speech_scored,
speech_falarm / speech_scored,
speaker_miss / speaker_scored,
speaker_falarm / speaker_scored,
speaker_error / speaker_scored,
correct / num_frames,
(speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
)
else:
sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0
stats = dict(
loss=loss.detach(),
loss_att=loss_att.detach() if loss_att is not None else None,
loss_pit=loss_pit.detach() if loss_pit is not None else None,
sad_mr=sad_mr,
sad_fr=sad_fr,
mi=mi,
fa=fa,
cf=cf,
acc=acc,
der=der,
)
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight

    def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
spk_labels: torch.Tensor = None,
spk_labels_lengths: torch.Tensor = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
bottleneck_feats: torch.Tensor,
bottleneck_feats_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
bottleneck_feats: (Batch, Length, ...): used for enh + diar
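
        Returns:
            encoder_out: (Batch, Length2, Dim)
            encoder_out_lens: (Batch,)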
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim)
if bottleneck_feats is None:
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
elif self.frontend is None:
# use only bottleneck feature
encoder_out, encoder_out_lens, _ = self.encoder(
bottleneck_feats, bottleneck_feats_lengths
)
else:
# use both frontend and bottleneck feats
# interpolate (copy) feats frames
# to match the length with bottleneck_feats
feats = F.interpolate(
feats.transpose(1, 2), size=bottleneck_feats.shape[1]
).transpose(1, 2)
            # concatenate the frontend feature (e.g., log-mel filterbank) and the bottleneck feature
encoder_out, encoder_out_lens, _ = self.encoder(
torch.cat((bottleneck_feats, feats), 2), bottleneck_feats_lengths
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = speech.shape[0]
speech_lengths = (
speech_lengths
if speech_lengths is not None
else torch.ones(batch_size).int() * speech.shape[1]
)
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths

    def pit_loss_single_permute(self, pred, label, length):
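        # pred, label: (Batch, T, num_spk); length: (Batch,)
        # Returns the per-utterance BCE loss for this particular label
        # permutation, shaped (Batch, 1) so that losses for different
        # permutations can be concatenated along dim=1 in pit_loss().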
bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
mask = self.create_length_mask(length, label.size(1), label.size(2))
loss = bce_loss(pred, label)
loss = loss * mask
loss = torch.sum(torch.mean(loss, dim=2), dim=1)
loss = torch.unsqueeze(loss, dim=1)
return loss

    def pit_loss(self, pred, label, lengths):
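        # Permutation-invariant training (PIT): evaluate the loss under every
        # permutation of the speaker columns of `label` and keep, for each
        # utterance, the permutation with the minimum loss.
        # Example with num_output = 2: the permutations (0, 1) and (1, 0) are
        # both tried; in general the search covers num_output! permutations.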
# Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
num_output = label.size(2)
permute_list = [np.array(p) for p in permutations(range(num_output))]
loss_list = []
for p in permute_list:
label_perm = label[:, :, p]
loss_perm = self.pit_loss_single_permute(pred, label_perm, lengths)
loss_list.append(loss_perm)
loss = torch.cat(loss_list, dim=1)
min_loss, min_idx = torch.min(loss, dim=1)
loss = torch.sum(min_loss) / torch.sum(lengths.float())
batch_size = len(min_idx)
label_list = []
for i in range(batch_size):
label_list.append(label[i, :, permute_list[min_idx[i]]].data.cpu().numpy())
label_permute = torch.from_numpy(np.array(label_list)).float()
return loss, min_idx, permute_list, label_permute

    def create_length_mask(self, length, max_len, num_output):
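        # Build a (batch_size, max_len, num_output) 0/1 mask that zeroes out
        # padded frames. Illustrative example: length = [2, 3], max_len = 3
        # -> mask[0] is 1 in its first 2 rows only, mask[1] is 1 in all 3 rows.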
batch_size = len(length)
mask = torch.zeros(batch_size, max_len, num_output)
for i in range(batch_size):
mask[i, : length[i], :] = 1
mask = to_device(self, mask)
return mask

    def attractor_loss(self, att_prob, label):
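        # EEND-EDA attractor existence loss: the first label.size(2) attractors
        # should exist (target 1) and the extra final one should not (target 0);
        # e.g., with 2 speakers the per-utterance target is [1, 1, 0].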
batch_size = len(label)
bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
# create attractor label [1, 1, ..., 1, 0]
# att_label: (Batch, num_spk + 1, 1)
att_label = to_device(self, torch.zeros(batch_size, label.size(2) + 1, 1))
att_label[:, : label.size(2), :] = 1
loss = bce_loss(att_prob, att_label)
loss = torch.mean(torch.mean(loss, dim=1))
return loss

    @staticmethod
def calc_diarization_error(pred, label, length):
# Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
(batch_size, max_len, num_output) = label.size()
# mask the padding part
mask = np.zeros((batch_size, max_len, num_output))
for i in range(batch_size):
mask[i, : length[i], :] = 1
# pred and label have the shape (batch_size, max_len, num_output)
label_np = label.data.cpu().numpy().astype(int)
pred_np = (pred.data.cpu().numpy() > 0).astype(int)
label_np = label_np * mask
pred_np = pred_np * mask
length = length.data.cpu().numpy()
# compute speech activity detection error
n_ref = np.sum(label_np, axis=2)
n_sys = np.sum(pred_np, axis=2)
speech_scored = float(np.sum(n_ref > 0))
speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))
# compute speaker diarization error
speaker_scored = float(np.sum(n_ref))
speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
num_frames = np.sum(length)
return (
correct,
num_frames,
speech_scored,
speech_miss,
speech_falarm,
speaker_scored,
speaker_miss,
speaker_falarm,
speaker_error,
)
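

if __name__ == "__main__":
    # Minimal usage sketch for the static error counter, with made-up values:
    # a batch of 2 utterances, 50 frames, and 2 speakers; `pred` holds logits
    # and `label` holds 0/1 speaker activities.
    torch.manual_seed(0)
    pred = torch.randn(2, 50, 2)
    label = (torch.rand(2, 50, 2) > 0.5).float()
    length = torch.tensor([50, 40])
    (
        correct,
        num_frames,
        speech_scored,
        speech_miss,
        speech_falarm,
        speaker_scored,
        speaker_miss,
        speaker_falarm,
        speaker_error,
    ) = ESPnetDiarizationModel.calc_diarization_error(pred, label, length)
    # DER from the counters, as in forward() (guarding against an empty reference)
    der = (speaker_miss + speaker_falarm + speaker_error) / max(speaker_scored, 1)
    print(f"frame accuracy: {correct / num_frames:.3f}, DER: {der:.3f}")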