from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.asr.transducer.error_calculator import ErrorCalculatorTransducer
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.slu.postdecoder.abs_postdecoder import AbsPostDecoder
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (  # noqa: H301
    LabelSmoothingLoss,
)
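
# torch.cuda.amp.autocast is only available from torch 1.6.0; on older versions a
# no-op context manager of the same name is defined below so that the
# `with autocast(False):` call in encode() works on every supported version.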
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class ESPnetSLUModel(ESPnetASRModel):
    """CTC-attention hybrid encoder-decoder model for spoken language understanding (SLU)."""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        joint_network: Optional[torch.nn.Module],
        postdecoder: Optional[AbsPostDecoder] = None,
        deliberationencoder: Optional[AbsPostEncoder] = None,
        transcript_token_list: Optional[Union[Tuple[str, ...], List[str]]] = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        two_pass: bool = False,
        pre_postencoder_norm: bool = False,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        AbsESPnetModel.__init__(self)
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = 0
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy()
        if transcript_token_list is not None:
            self.transcript_token_list = transcript_token_list.copy()
        self.two_pass = two_pass
        self.pre_postencoder_norm = pre_postencoder_norm

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.postdecoder = postdecoder
        self.encoder = encoder
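        # If a text postdecoder is configured (e.g. a BERT-style encoder applied to
        # the transcript) and its hidden size differs from the speech encoder's
        # output size, a linear projection is created to map the encoder output to
        # the postdecoder's dimension.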
        if self.postdecoder is not None:
            if self.encoder._output_size != self.postdecoder.output_size_dim:
                self.uniform_linear = torch.nn.Linear(
                    self.encoder._output_size, self.postdecoder.output_size_dim
                )
        self.deliberationencoder = deliberationencoder
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        self.use_transducer_decoder = joint_network is not None
        self.error_calculator = None
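
        # The decoder branch is set up in one of two ways: if a joint network is
        # given, a transducer (RNN-T) decoder and loss are used; otherwise the
        # standard CTC/attention hybrid with a label-smoothing loss is configured.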
        if self.use_transducer_decoder:
            from warprnnt_pytorch import RNNTLoss

            self.decoder = decoder
            self.joint_network = joint_network

            self.criterion_transducer = RNNTLoss(
                blank=self.blank_id,
                fastemit_lambda=0.0,
            )

            if report_cer or report_wer:
                self.error_calculator_trans = ErrorCalculatorTransducer(
                    decoder,
                    joint_network,
                    token_list,
                    sym_space,
                    sym_blank,
                    report_cer=report_cer,
                    report_wer=report_wer,
                )
            else:
                self.error_calculator_trans = None

                if self.ctc_weight != 0:
                    self.error_calculator = ErrorCalculator(
                        token_list, sym_space, sym_blank, report_cer, report_wer
                    )
        else:
            # we set self.decoder = None in the CTC mode since
            # self.decoder parameters were never used and PyTorch complained
            # and threw an Exception in the multi-GPU experiment.
            # thanks Jeff Farris for pointing out the issue.
            if ctc_weight == 1.0:
                self.decoder = None
            else:
                self.decoder = decoder

            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

            if report_cer or report_wer:
                self.error_calculator = ErrorCalculator(
                    token_list, sym_space, sym_blank, report_cer, report_wer
                )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        transcript: torch.Tensor = None,
        transcript_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
kwargs: "utt_id" is among the input.
"""
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(
            speech, speech_lengths, transcript, transcript_lengths
        )
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
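            # The total CTC loss interpolates the final-layer CTC loss with the
            # average of the intermediate-layer CTC losses:
            #   loss_ctc = (1 - interctc_weight) * loss_ctc
            #              + interctc_weight * mean_i(loss_interctc_layer_i)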
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        if self.use_transducer_decoder:
            # 2a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer
        else:
            # 2b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 3. CTC-Att loss definition
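            # loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att,
            # reducing to a single branch when ctc_weight is exactly 0.0 or 1.0.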
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        transcript: torch.Tensor = None,
        transcript_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        transcript_pad: torch.Tensor = None,
        transcript_pad_lens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats,
                feats_lengths,
            )
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )
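
        # Optional second pass: when a postdecoder is configured, the padded ASR
        # transcript is decoded back to strings, re-encoded by the postdecoder
        # (a BERT-style text encoder), and its embeddings are concatenated with the
        # acoustic encoder output before an optional deliberation encoder.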
        if self.postdecoder is not None:
            if self.encoder._output_size != self.postdecoder.output_size_dim:
                encoder_out = self.uniform_linear(encoder_out)
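
            # Convert the padded transcript token ids back to space-joined token
            # strings; entries equal to -1 (the padding value) are skipped.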
            transcript_list = [
                " ".join([self.transcript_token_list[int(k)] for k in k1 if k != -1])
                for k1 in transcript_pad
            ]
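
            # Tokenize the transcript strings into the postdecoder's input features
            # (token ids, input mask, segment ids, position ids), with 128 passed as
            # the maximum sequence length.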
            (
                transcript_input_id_features,
                transcript_input_mask_features,
                transcript_segment_ids_feature,
                transcript_position_ids_feature,
                input_id_length,
            ) = self.postdecoder.convert_examples_to_features(transcript_list, 128)
            bert_encoder_out = self.postdecoder(
                torch.LongTensor(transcript_input_id_features).to(device=speech.device),
                torch.LongTensor(transcript_input_mask_features).to(
                    device=speech.device
                ),
                torch.LongTensor(transcript_segment_ids_feature).to(
                    device=speech.device
                ),
                torch.LongTensor(transcript_position_ids_feature).to(
                    device=speech.device
                ),
            )
            bert_encoder_lens = torch.LongTensor(input_id_length).to(
                device=speech.device
            )
            bert_encoder_out = bert_encoder_out[:, : torch.max(bert_encoder_lens)]
            final_encoder_out_lens = encoder_out_lens + bert_encoder_lens
            max_lens = torch.max(final_encoder_out_lens)
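
            # Concatenate, per utterance, the valid speech-encoder frames with the
            # valid postdecoder frames, then zero-pad each sequence up to the
            # combined maximum length.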
            encoder_new_out = torch.zeros(
                (encoder_out.shape[0], max_lens, encoder_out.shape[2])
            ).to(device=speech.device)
            for k in range(len(encoder_out)):
                encoder_new_out[k] = torch.cat(
                    (
                        encoder_out[k, : encoder_out_lens[k]],
                        bert_encoder_out[k, : bert_encoder_lens[k]],
                        torch.zeros(
                            (max_lens - final_encoder_out_lens[k], encoder_out.shape[2])
                        ).to(device=speech.device),
                    ),
                    0,
                )
            if self.deliberationencoder is not None:
                encoder_new_out, final_encoder_out_lens = self.deliberationencoder(
                    encoder_new_out, final_encoder_out_lens
                )
            encoder_out = encoder_new_out
            encoder_out_lens = final_encoder_out_lens

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens
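

# A minimal usage sketch (illustrative only): it assumes an already constructed
# `model: ESPnetSLUModel` with a raw-waveform frontend, and is not part of this
# module.
#
#     speech = torch.randn(2, 16000)                 # (Batch, Samples)
#     speech_lengths = torch.tensor([16000, 12000])  # (Batch,)
#     text = torch.tensor([[5, 6, 7], [5, 6, -1]])   # (Batch, Length), -1 = padding
#     text_lengths = torch.tensor([3, 2])            # (Batch,)
#     loss, stats, weight = model(speech, speech_lengths, text, text_lengths)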