Source code for espnet2.hubert.espnet_model

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Thanks to Abdelrahman Mohamed and Wei-Ning Hsu for their help with this
# implementation. Their original HuBERT work is in:
#     Paper: https://arxiv.org/pdf/2106.07447.pdf
#     Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/hubert

import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
from packaging.version import parse as V
from typeguard import check_argument_types

from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.hubert.hubert_loss import HubertPretrainLoss
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet.nets.e2e_asr_common import ErrorCalculator

if V(torch.__version__) < V("1.6.0"):
    # torch<1.6.0: nothing is imported, so expose a context manager with the
    # same call shape that simply does nothing.
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in used in place of ``torch.cuda.amp.autocast``."""
        yield

else:
    from torch.cuda.amp import autocast


class TorchAudioHubertPretrainModel(AbsESPnetModel):
    """TorchAudio Hubert Pretrain model.

    Pipeline: frontend -> SpecAug (training only) -> normalization ->
    pre-encoder -> encoder.  The encoder output is unpacked as
    ``(logit_m, logit_u, feature_penalty)`` and turned into a cross-entropy
    loss on masked / unmasked frames plus a feature penalty term.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        ignore_id: int = -1,
    ):
        """Store the sub-modules and bookkeeping attributes.

        Args:
            vocab_size: Size of the target vocabulary.
            token_list: Token symbols; a private list copy is stored.
            frontend: Optional feature extractor applied to ``speech``.
            specaug: Optional augmentation, applied only in training mode.
            normalize: Optional feature normalization.
            preencoder: Optional module applied right before the encoder.
            encoder: Encoder called as
                ``encoder(feats, feats_lengths, y_pad, y_pad_length)``.
            ignore_id: Stored as ``self.ignore_id``; not used by this class.
        """
        assert check_argument_types()

        super().__init__()
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # list(...) rather than .copy(): a tuple is an accepted input type
        # but has no .copy() method.
        self.token_list = list(token_list)

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder

        self.error_calculator = None

        # Number of forward passes whose loss was inf/nan (for logging only).
        self.nan_loss_count = 0.0

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            kwargs: "utt_id" is among the input.

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``;
            ``weight`` is derived from the batch size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        logit_m, logit_u, feature_penalty = self.encode(
            speech, speech_lengths, text, text_lengths
        )

        # 2a. Hubert criterion
        loss = self._calc_hubert_loss(
            logit_m,
            logit_u,
            feature_penalty,
        )

        # Only count and report non-finite losses; the loss is still returned.
        if torch.isinf(loss) or torch.isnan(loss):
            self.nan_loss_count += 1
            logging.warning(f"nan_loss_count, {self.nan_loss_count}")

        # log accuracies of masked and unmasked frames
        correct_m, count_m = self._compute_correct(logit_m)
        correct_u, count_u = self._compute_correct(logit_u)

        stats = dict(
            loss=loss.detach(),
            correct_m=correct_m,
            count_m=count_m,
            # _compute_correct reports count 0 for empty logits; avoid a
            # ZeroDivisionError in that case and report accuracy 0.0.
            acc_m=correct_m / count_m if count_m > 0 else 0.0,
            correct_u=correct_u,
            count_u=count_u,
            acc_u=correct_u / count_u if count_u > 0 else 0.0,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Return the frontend features and their lengths.

        ``text``, ``text_lengths`` and ``kwargs`` are accepted but unused.
        """
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        y_pad: torch.Tensor,
        y_pad_length: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            y_pad: (Batch, Length, ...)
            y_pad_length: (Batch, )

        Returns:
            The encoder output unchanged.  ``forward`` unpacks it as
            ``(logit_m, logit_u, feature_penalty)``.
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out = self.encoder(feats, feats_lengths, y_pad, y_pad_length)

        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the frontend, or pass ``speech`` through if there is none."""
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def _compute_correct(
        self,
        logits,
    ):
        """Count rows of ``logits`` whose prediction is index 0.

        Index 0 is the target class used by ``_calc_hubert_loss``.  A row
        where index 0 is both the argmax and the argmin is not counted as
        correct.

        Args:
            logits: Tensor with at least 2 dimensions, or ``None`` / empty.

        Returns:
            ``(corr, count)`` as Python ints; ``(0, 0)`` for ``None`` or
            empty input.
        """
        if logits is None or logits.numel() == 0:
            return 0, 0

        assert logits.dim() > 1, logits.shape
        is_argmax = logits.argmax(-1) == 0
        is_argmin = logits.argmin(-1) == 0
        both = is_argmax & is_argmin
        corr = is_argmax.long().sum().item() - both.long().sum().item()
        count = is_argmax.numel()
        return corr, count

    def _calc_hubert_loss(
        self,
        logit_m: Optional[torch.Tensor],
        logit_u: Optional[torch.Tensor],
        feature_penalty: torch.Tensor,
        masked_weight: float = 1.0,
        unmasked_weight: float = 0.0,
        feature_weight: float = 10.0,
        reduction: str = "sum",
    ) -> torch.Tensor:
        """Compute the cross-entropy loss on HuBERT masked and non-masked logits.

        Args:
            logit_m (Tensor or None): The masked logit Tensor of dimension
                `(masked_frames, final_dim)`.
            logit_u (Tensor or None): The non-masked logit Tensor of dimension
                `(unmasked_frames, final_dim)`.
            feature_penalty (Tensor): The feature mean value for additional
                penalty loss.
            masked_weight (float, optional): The weight for masked cross-entropy
                loss (Default: ``1.0``).
            unmasked_weight (float, optional): The weight for non-masked
                cross-entropy loss (Default: ``0.0``).
            feature_weight (float, optional): The weight for feature penalty
                loss (Default: ``10.0``).
            reduction (str, optional): The reduction method for cross-entropy
                loss (Default: ``"sum"``).

        Returns:
            Tensor: The weighted sum of the feature penalty and the two
            cross-entropy terms.

        Ref:
            torchaudio: examples/hubert/loss/hubert_loss.py
        """
        # The penalty is scaled by the number of masked frames.  If no masked
        # logits are given, fall back to the unmasked frame count (0 if
        # neither is given) instead of dereferencing None.
        if logit_m is not None:
            sample_size = logit_m.shape[0]
        elif logit_u is not None:
            sample_size = logit_u.shape[0]
        else:
            sample_size = 0

        loss = feature_penalty * feature_weight * sample_size

        # The target class is always index 0 for both masked and unmasked frames.
        if logit_m is not None:
            target_m = torch.zeros(
                logit_m.shape[0], dtype=torch.long, device=logit_m.device
            )
            loss_m = torch.nn.functional.cross_entropy(
                logit_m, target_m, reduction=reduction
            )
            loss += loss_m * masked_weight

        if logit_u is not None:
            # Targets must live on logit_u's own device (not logit_m's).
            target_u = torch.zeros(
                logit_u.shape[0], dtype=torch.long, device=logit_u.device
            )
            loss_u = torch.nn.functional.cross_entropy(
                logit_u, target_u, reduction=reduction
            )
            loss += loss_u * unmasked_weight

        return loss
class HubertPretrainModel(AbsESPnetModel):
    """Hubert Pretrain model.

    Pipeline: frontend -> SpecAug (training only) -> normalization ->
    pre-encoder -> encoder.  The loss is computed by ``HubertPretrainLoss``
    from the encoder output and ``self.encoder.encoder``.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = False,
        report_wer: bool = False,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        pred_masked_weight: float = 1.0,
        pred_nomask_weight: float = 0.0,
        loss_weights: float = 0.0,
    ):
        """Store the sub-modules and build the HuBERT criterion.

        Args:
            vocab_size: Size of the target vocabulary; the last ID is used
                for both sos and eos.
            token_list: Token symbols; a private list copy is stored.
            frontend: Optional feature extractor applied to ``speech``.
            specaug: Optional augmentation, applied only in training mode.
            normalize: Optional feature normalization.
            preencoder: Optional module applied right before the encoder.
            encoder: Encoder called as
                ``encoder(feats, feats_lengths, y_pad, y_pad_length)``.
            ignore_id: Stored as ``self.ignore_id``.
            lsm_weight: Accepted but not used by this class.
            length_normalized_loss: Accepted but not used by this class.
            report_cer: If true, an ``ErrorCalculator`` is created.
            report_wer: If true, an ``ErrorCalculator`` is created.
            sym_space: Space symbol passed to ``ErrorCalculator``.
            sym_blank: Blank symbol passed to ``ErrorCalculator``.
            pred_masked_weight: Passed to ``HubertPretrainLoss``.
            pred_nomask_weight: Passed to ``HubertPretrainLoss``.
            loss_weights: Passed to ``HubertPretrainLoss``.
        """
        assert check_argument_types()

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # list(...) rather than .copy(): a tuple is an accepted input type
        # but has no .copy() method.
        self.token_list = list(token_list)

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        self.criterion_hubert = HubertPretrainLoss(
            pred_masked_weight,
            pred_nomask_weight,
            loss_weights,
        )
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            kwargs: "utt_id" is among the input.

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``;
            ``weight`` is derived from the batch size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out = self.encode(speech, speech_lengths, text, text_lengths)

        # 2a. Hubert criterion
        loss, acc_mask, acc_unmask = self._calc_hubert_loss(
            encoder_out,
        )

        stats = dict(
            loss=loss.detach(),
            acc_mask=acc_mask,
            acc_unmask=acc_unmask,
            # "acc" reports the masked-frame accuracy.
            acc=acc_mask,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Return the frontend features and their lengths.

        ``text``, ``text_lengths`` and ``kwargs`` are accepted but unused.
        """
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        y_pad: torch.Tensor,
        y_pad_length: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            y_pad: (Batch, Length, ...)
            y_pad_length: (Batch, )

        Returns:
            The encoder output unchanged.  NOTE(review): annotated as a dict
            to match ``_calc_hubert_loss``, which consumes it - confirm
            against the encoder implementation.
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out = self.encoder(feats, feats_lengths, y_pad, y_pad_length)

        # Sanity check: a non-zero prediction weight needs at least one set
        # of logits of the corresponding kind.
        if hasattr(self.encoder, "encoder"):
            logp_m_list = self.encoder.encoder.get_logits(encoder_out, True)
            assert self.pred_masked_weight == 0 or len(logp_m_list) > 0

            logp_u_list = self.encoder.encoder.get_logits(encoder_out, False)
            assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0

        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the frontend, or pass ``speech`` through if there is none."""
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def compute_correct(
        self,
        logits,
    ):
        """Count rows of ``logits`` whose prediction is index 0.

        A row where index 0 is both the argmax and the argmin is not counted
        as correct.

        Args:
            logits: Tensor with at least 2 dimensions (may be empty).

        Returns:
            ``(corr, count)`` as Python ints; ``(0, 0)`` for empty input.
        """
        if logits.numel() == 0:
            return 0, 0

        assert logits.dim() > 1, logits.shape
        is_argmax = logits.argmax(-1) == 0
        is_argmin = logits.argmin(-1) == 0
        both = is_argmax & is_argmin
        corr = is_argmax.long().sum().item() - both.long().sum().item()
        count = is_argmax.numel()
        return corr, count

    def _calc_hubert_loss(
        self,
        encoder_out: Dict[str, torch.Tensor],
    ):
        """Compute the HuBERT loss and masked / unmasked accuracies.

        Args:
            encoder_out: Output of ``self.encoder`` as returned by ``encode``.

        Returns:
            ``(loss, acc_m, acc_u)`` where the accuracies are Python floats
            (0.0 when no frames were counted).
        """
        # 1. Compute hubert loss
        loss, logp_m_list, logp_u_list = self.criterion_hubert(
            self.encoder.encoder, encoder_out
        )

        corr_masked, count_masked = 0, 0
        corr_unmask, count_unmask = 0, 0
        with torch.no_grad():
            for logp_m in logp_m_list:
                corr_m, count_m = self.compute_correct(logp_m)
                corr_masked += corr_m
                count_masked += count_m
            for logp_u in logp_u_list:
                corr_u, count_u = self.compute_correct(logp_u)
                corr_unmask += corr_u
                count_unmask += count_u

        # The epsilon keeps the division defined when nothing was counted.
        acc_m = corr_masked / (count_masked + 1e-10)
        acc_u = corr_unmask / (count_unmask + 1e-10)

        return loss, acc_m, acc_u