import logging
from contextlib import contextmanager
from itertools import groupby
from typing import Dict, List, Optional, Tuple, Union
import numpy
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.mlm_decoder import MLMDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( # noqa: H301
LabelSmoothingLoss,
)

if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
    # torch<1.6.0 has no torch.cuda.amp.autocast; fall back to a no-op context
@contextmanager
def autocast(enabled=True):
yield


class MaskCTCModel(ESPnetASRModel):
"""Hybrid CTC/Masked LM Encoder-Decoder model (Mask-CTC)"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: MLMDecoder,
ctc: CTC,
joint_network: Optional[torch.nn.Module] = None,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
sym_mask: str = "<mask>",
extract_feats_in_collect_stats: bool = True,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
joint_network=joint_network,
ctc_weight=ctc_weight,
interctc_weight=interctc_weight,
ignore_id=ignore_id,
lsm_weight=lsm_weight,
length_normalized_loss=length_normalized_loss,
report_cer=report_cer,
report_wer=report_wer,
sym_space=sym_space,
sym_blank=sym_blank,
extract_feats_in_collect_stats=extract_feats_in_collect_stats,
)
# Add <mask> and override inherited fields
token_list.append(sym_mask)
vocab_size += 1
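        # the appended <mask> takes the last index of the enlarged vocabulary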
self.vocab_size = vocab_size
self.mask_token = vocab_size - 1
self.token_list = token_list.copy()
# MLM loss
del self.criterion_att
self.criterion_mlm = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.error_calculator = None
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer
)

    def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# For data-parallel
text = text[:, : text_lengths.max()]
# Define stats to report
loss_mlm, acc_mlm = None, None
loss_ctc, cer_ctc = None, None
stats = dict()
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# 2. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# 2a. Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
                # Collect intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 3. MLM decoder branch
if self.ctc_weight != 1.0:
loss_mlm, acc_mlm = self._calc_mlm_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 4. CTC/MLM loss definition
if self.ctc_weight == 0.0:
loss = loss_mlm
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_mlm
# Collect MLM branch stats
stats["loss_mlm"] = loss_mlm.detach() if loss_mlm is not None else None
stats["acc_mlm"] = acc_mlm
# Collect total loss stats
stats["loss"] = loss.detach()
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight

    def _calc_mlm_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# 1. Apply masks
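        # mask_uniform replaces a randomly chosen subset of tokens in ys_in_pad
        # with <mask>; ys_out_pad keeps targets only at those positions
        # (ignore_id elsewhere), so the MLM loss covers masked tokens only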
ys_in_pad, ys_out_pad = mask_uniform(
ys_pad, self.mask_token, self.eos, self.ignore_id
)
# 2. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_pad_lens
)
# 3. Compute mlm loss
loss_mlm = self.criterion_mlm(decoder_out, ys_out_pad)
acc_mlm = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
return loss_mlm, acc_mlm

    def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
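        """Not implemented: Mask-CTC's MLM decoder is non-autoregressive."""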
raise NotImplementedError

    def batchify_nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
batch_size: int = 100,
):
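        """Not implemented: see ``nll``."""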
raise NotImplementedError


class MaskCTCInference(torch.nn.Module):
"""Mask-CTC-based non-autoregressive inference"""

    def __init__(
self,
asr_model: MaskCTCModel,
n_iterations: int,
threshold_probability: float,
):
"""Initialize Mask-CTC inference"""
super().__init__()
self.ctc = asr_model.ctc
self.mlm = asr_model.decoder
self.mask_token = asr_model.mask_token
self.n_iterations = n_iterations
self.threshold_probability = threshold_probability
self.converter = TokenIDConverter(token_list=asr_model.token_list)

    def ids2text(self, ids: List[int]):
text = "".join(self.converter.ids2tokens(ids))
return text.replace("<mask>", "_").replace("<space>", " ")

    def forward(self, enc_out: torch.Tensor) -> Hypothesis:
"""Perform Mask-CTC inference"""
# greedy ctc outputs
enc_out = enc_out.unsqueeze(0)
ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(enc_out)).max(dim=-1)
y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
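        # keep indices of non-blank tokens (CTC blank is assumed to be id 0)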
y_idx = torch.nonzero(y_hat != 0).squeeze(-1)
logging.info("ctc:{}".format(self.ids2text(y_hat[y_idx].tolist())))
# calculate token-level ctc probabilities by taking
# the maximum probability of consecutive frames with
# the same ctc symbols
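        # e.g., ctc_ids[0] = [5, 5, 0, 7] with frame probs [0.3, 0.9, 1.0, 0.8]
        #       collapses to y_hat = [5, 0, 7] with probs_hat = [0.9, 1.0, 0.8]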
probs_hat = []
cnt = 0
for i, y in enumerate(y_hat.tolist()):
probs_hat.append(-1)
while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
if probs_hat[i] < ctc_probs[0][cnt]:
probs_hat[i] = ctc_probs[0][cnt].item()
cnt += 1
probs_hat = torch.from_numpy(numpy.array(probs_hat)).to(enc_out.device)
# mask ctc outputs based on ctc probabilities
p_thres = self.threshold_probability
mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
mask_num = len(mask_idx)
y_in = (
torch.zeros(1, len(y_idx), dtype=torch.long).to(enc_out.device)
+ self.mask_token
)
y_in[0][confident_idx] = y_hat[y_idx][confident_idx]
logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))
# iterative decoding
        if mask_num != 0:
K = self.n_iterations
num_iter = K if mask_num >= K and K > 0 else mask_num
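            # easy-first decoding: each pass fills in the mask_num // num_iter
            # most confident masked positions; the final pass (below) fills
            # whatever masks remain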
for t in range(num_iter - 1):
pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
y_in[0][mask_idx[cand]] = pred_id[cand]
mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)
logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))
# predict leftover masks (|masks| < mask_num // num_iter)
pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)
logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))
# pad with mask tokens to ensure compatibility with sos/eos tokens
yseq = torch.tensor(
[self.mask_token] + y_in.tolist()[0] + [self.mask_token], device=y_in.device
)
return Hypothesis(yseq=yseq)
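

# Usage sketch (hypothetical variable names; assumes `asr_model` is a trained
# MaskCTCModel and `enc_out` is one utterance's encoder output of shape (T, D),
# e.g. taken from asr_model.encode()):
#
#     maskctc = MaskCTCInference(
#         asr_model=asr_model,
#         n_iterations=10,
#         threshold_probability=0.99,
#     )
#     hyp = maskctc(enc_out)  # Hypothesis; yseq is wrapped with <mask> ids
#                             # standing in for sos/eos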