import itertools
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.espnet_model import ESPnetASRModel as SingleESPnetASRModel
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
    # torch<1.6.0 has no native AMP; provide a no-op autocast fallback.
@contextmanager
def autocast(enabled=True):
yield
class PITLossWrapper(AbsLossWrapper):
    """Permutation-Invariant Training (PIT) loss wrapper for multi-speaker ASR."""
def __init__(self, criterion_fn: Callable, num_ref: int):
super().__init__()
self.criterion_fn = criterion_fn
self.num_ref = num_ref
    def forward(
self,
inf: torch.Tensor,
inf_lens: torch.Tensor,
ref: torch.Tensor,
ref_lens: torch.Tensor,
        others: Optional[Dict] = None,
):
"""PITLoss Wrapper function. Similar to espnet2/enh/loss/wrapper/pit_solver.py
Args:
inf: Iterable[torch.Tensor], (batch, num_inf, ...)
inf_lens: Iterable[torch.Tensor], (batch, num_inf, ...)
ref: Iterable[torch.Tensor], (batch, num_ref, ...)
ref_lens: Iterable[torch.Tensor], (batch, num_ref, ...)
permute_inf: If true, permute the inference and inference_lens according to
the optimal permutation.
"""
assert (
self.num_ref
== inf.shape[1]
== inf_lens.shape[1]
== ref.shape[1]
== ref_lens.shape[1]
), (self.num_ref, inf.shape, inf_lens.shape, ref.shape, ref_lens.shape)
all_permutations = torch.as_tensor(
list(itertools.permutations(range(self.num_ref), r=self.num_ref))
)
stats = defaultdict(list)
        def pre_hook(func, *args, **kwargs):
            # Call the criterion and collect any stats it reports as a side effect.
            ret = func(*args, **kwargs)
            for k, v in getattr(self.criterion_fn, "stats", {}).items():
                stats[k].append(v)
            return ret
        def pair_loss(permutation):
            # i indexes references; j = permutation[i] indexes inference channels.
return sum(
[
pre_hook(
self.criterion_fn,
inf[:, j],
inf_lens[:, j],
ref[:, i],
ref_lens[:, i],
)
for i, j in enumerate(permutation)
]
) / len(permutation)
losses = torch.stack(
[pair_loss(p) for p in all_permutations], dim=1
) # (batch_size, num_perm)
min_losses, min_ids = torch.min(losses, dim=1)
min_ids = min_ids.cpu() # because all_permutations is a cpu tensor.
opt_perm = all_permutations[min_ids] # (batch_size, num_ref)
# Permute the inf and inf_lens according to the optimal perm
return min_losses.mean(), opt_perm
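        # Note: the search above scans all num_ref! permutations; that is cheap
        # for the 2-3 speaker mixtures this model targets, but the cost grows
        # factorially with the number of speakers.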
    @classmethod
    def permutate(cls, perm, *args):
        """Reorder each tensor in ``args`` along dim 1 by the per-sample permutation ``perm``."""
ret = []
batch_size = None
num_ref = None
for arg in args: # (batch, num_inf, ...)
if batch_size is None:
batch_size, num_ref = arg.shape[:2]
else:
assert torch.Size([batch_size, num_ref]) == arg.shape[:2]
ret.append(
torch.stack(
[arg[torch.arange(batch_size), perm[:, i]] for i in range(num_ref)],
dim=1,
)
)
return ret
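# Example (num_ref=2): perm[b] == (1, 0) means that, for sample b, inference
# channel 1 matched reference 0 and channel 0 matched reference 1; permutate()
# then gathers the channels so that output channel i lines up with reference i.
# A runnable sketch is given at the bottom of this file.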
class ESPnetASRModel(SingleESPnetASRModel):
    """CTC-attention hybrid encoder-decoder model for multi-speaker ASR with PIT."""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: Optional[AbsDecoder],
ctc: CTC,
joint_network: Optional[torch.nn.Module],
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
# In a regular ESPnet recipe, <sos> and <eos> are both "<sos/eos>"
# Pretrained HF Tokenizer needs custom sym_sos and sym_eos
sym_sos: str = "<sos/eos>",
sym_eos: str = "<sos/eos>",
extract_feats_in_collect_stats: bool = True,
lang_token_id: int = -1,
# num_inf: the number of inferences (= number of outputs of the model)
# num_ref: the number of references (= number of groundtruth seqs)
num_inf: int = 1,
num_ref: int = 1,
):
assert check_argument_types()
assert 0.0 < ctc_weight <= 1.0, ctc_weight
assert interctc_weight == 0.0, "interctc is not supported for multispeaker ASR"
super(ESPnetASRModel, self).__init__(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
joint_network=joint_network,
ctc_weight=ctc_weight,
interctc_weight=interctc_weight,
ignore_id=ignore_id,
lsm_weight=lsm_weight,
length_normalized_loss=length_normalized_loss,
report_cer=report_cer,
report_wer=report_wer,
sym_space=sym_space,
sym_blank=sym_blank,
sym_sos=sym_sos,
sym_eos=sym_eos,
extract_feats_in_collect_stats=extract_feats_in_collect_stats,
lang_token_id=lang_token_id,
)
assert num_inf == num_ref, "Current PIT loss wrapper requires num_inf=num_ref"
self.num_inf = num_inf
self.num_ref = num_ref
self.pit_ctc = PITLossWrapper(criterion_fn=self.ctc, num_ref=num_ref)
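    # In a multi-speaker recipe, both counts are typically set together in the
    # model configuration, e.g. (illustrative values, not a fixed recipe):
    #
    #     model_conf:
    #         num_ref: 2
    #         num_inf: 2
    #
    # The extra transcripts then arrive as text_spk2, text_spk3, ... in the
    # kwargs of forward() below.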
    def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
kwargs: "utt_id" is among the input.
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
        # Gather the reference transcripts of all speakers: `text` is speaker 1,
        # and text_spk2, ..., text_spkN come from kwargs.
text_ref = [text] + [
kwargs["text_spk{}".format(spk + 1)] for spk in range(1, self.num_ref)
]
text_ref_lengths = [text_lengths] + [
kwargs.get("text_spk{}_lengths".format(spk + 1), None)
for spk in range(1, self.num_ref)
]
        assert all(ref_lengths.dim() == 1 for ref_lengths in text_ref_lengths), [
            ref_lengths.shape for ref_lengths in text_ref_lengths
        ]
text_lengths = torch.stack(text_ref_lengths, dim=1) # (batch, num_ref)
text_length_max = text_lengths.max()
# pad text sequences of different speakers to the same length
text = torch.stack(
[
torch.nn.functional.pad(
ref, (0, text_length_max - ref.shape[1]), value=self.ignore_id
)
for ref in text_ref
],
dim=1,
) # (batch, num_ref, seq_len)
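        # Shape example with num_ref=2: if text is (B, 5) and text_spk2 is
        # (B, 7), the padded stack is (B, 2, 7), with positions beyond each
        # true length filled by self.ignore_id.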
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
loss_transducer, cer_transducer, wer_transducer = None, None, None
stats = dict()
        # 2. CTC branch
if self.ctc_weight != 0.0:
# CTC is computed twice
# This 1st ctc calculation is only used to decide permutation
_, perm = self.pit_ctc(encoder_out, encoder_out_lens, text, text_lengths)
encoder_out, encoder_out_lens = PITLossWrapper.permutate(
perm, encoder_out, encoder_out_lens
)
            if text.dim() == 3:  # combine all speakers' hidden vectors and labels
encoder_out = encoder_out.reshape(-1, *encoder_out.shape[2:])
encoder_out_lens = encoder_out_lens.reshape(-1)
text = text.reshape(-1, text.shape[-1])
text_lengths = text_lengths.reshape(-1)
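            # After flattening, each (sample, speaker) pair is treated as an
            # independent utterance of shape (batch * num_ref, ...), so the
            # single-speaker loss computations below apply unchanged.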
# This 2nd ctc calculation is to compute the loss
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
loss_ctc = loss_ctc.sum()
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
if self.use_transducer_decoder:
            # 3a. Transducer decoder branch
(
loss_transducer,
cer_transducer,
wer_transducer,
) = self._calc_transducer_loss(
encoder_out,
encoder_out_lens,
text,
)
if loss_ctc is not None:
loss = loss_transducer + (self.ctc_weight * loss_ctc)
else:
loss = loss_transducer
# Collect Transducer branch stats
stats["loss_transducer"] = (
loss_transducer.detach() if loss_transducer is not None else None
)
stats["cer_transducer"] = cer_transducer
stats["wer_transducer"] = wer_transducer
else:
            # 3b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
            # 4. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
# Collect total loss stats
stats["loss"] = loss.detach()
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
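

if __name__ == "__main__":
    # A minimal, self-contained sketch of PITLossWrapper, assuming a toy
    # per-sample L1 criterion (`toy_l1` is illustrative, not part of ESPnet).
    def toy_l1(inf, inf_lens, ref, ref_lens):
        # Per-utterance loss of shape (batch,), as the wrapper expects.
        return (inf - ref).abs().mean(dim=-1)

    pit = PITLossWrapper(criterion_fn=toy_l1, num_ref=2)
    inf = torch.randn(4, 2, 10)  # (batch, num_inf, feature)
    lens = torch.full((4, 2), 10)
    # References are the model outputs with the speaker channels swapped, so
    # the optimal permutation should be (1, 0) for every sample.
    loss, perm = pit(inf, lens, inf.flip(1), lens)
    (inf_aligned,) = PITLossWrapper.permutate(perm, inf)
    assert torch.equal(inf_aligned, inf.flip(1))
    print(loss, perm)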