Source code for espnet2.hubert.hubert_loss

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# The HubertPretrainLoss Module uses code from Fairseq:
#     https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/hubert_criterion.py
#
# Thanks to Abdelrahman Mohamed and Wei-Ning Hsu for their help with this implementation.
# Their original HuBERT work is available at:
#     Paper: https://arxiv.org/pdf/2106.07447.pdf
#     Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/hubert

"""Hubert Pretrain Loss module."""

import torch.nn.functional as F
from torch import nn


class HubertPretrainLoss(nn.Module):
    """Hubert criterion module.

    Args:
        pred_masked_weight: weight for predictive loss for masked frames
        pred_nomask_weight: weight for predictive loss for unmasked frames
        loss_weights: weights for additional loss terms (not first one)
    """

    def __init__(
        self,
        pred_masked_weight: float = 1.0,
        pred_nomask_weight: float = 0.0,
        loss_weights: float = 10.0,
    ):
        super(HubertPretrainLoss, self).__init__()
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
    def forward(self, model, enc_outputs, reduce=True):
        """Compute the HuBERT pretraining loss.

        Args:
            model: HuBERT pretraining model exposing ``get_logits``,
                ``get_targets``, and ``get_extra_losses``.
            enc_outputs: encoder outputs produced by ``model``.
            reduce: if True, sum the per-frame cross-entropy losses.
        """
        loss = 0.0
        sample_size = 0
        reduction = "sum" if reduce else "none"

        # Predictive loss over masked frames
        loss_m_list = []
        logp_m_list = model.get_logits(enc_outputs, True)
        targ_m_list = model.get_targets(enc_outputs, True)
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        # Predictive loss over unmasked frames
        loss_u_list = []
        logp_u_list = model.get_logits(enc_outputs, False)
        targ_u_list = model.get_targets(enc_outputs, False)
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        # Additional loss term (e.g. feature penalty), scaled by sample size
        if self.loss_weights > 0:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(enc_outputs)
            if isinstance(extra_losses, list):
                extra_losses = extra_losses[0]
                names = names[0]
            else:
                raise NotImplementedError("only support one extra loss")
            loss += self.loss_weights * extra_losses.float() * sample_size

        return loss, logp_m_list, logp_u_list
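
A minimal sketch of how this criterion could be driven, assuming a hypothetical stand-in model and made-up ``enc_outputs`` keys and tensor shapes; the real interface is provided by ESPnet's HuBERT pretraining model, which implements ``get_logits``, ``get_targets``, and ``get_extra_losses`` over its actual encoder outputs.

import torch

# Hypothetical stand-in exposing the interface HubertPretrainLoss expects:
# get_logits / get_targets return lists of per-target-set tensors, and
# get_extra_losses returns (list_of_scalar_losses, list_of_names).
class DummyHubertModel(torch.nn.Module):
    def get_logits(self, enc_outputs, is_masked):
        key = "logit_m_list" if is_masked else "logit_u_list"
        return enc_outputs[key]

    def get_targets(self, enc_outputs, is_masked):
        key = "target_m_list" if is_masked else "target_u_list"
        return enc_outputs[key]

    def get_extra_losses(self, enc_outputs):
        return [enc_outputs["features_pen"]], ["features_pen"]


# Fake encoder outputs: (num_frames, num_classes) logits and integer targets
# for masked and unmasked frames, plus one scalar extra loss (assumed keys).
enc_outputs = {
    "logit_m_list": [torch.randn(30, 500)],
    "target_m_list": [torch.randint(0, 500, (30,))],
    "logit_u_list": [torch.randn(70, 500)],
    "target_u_list": [torch.randint(0, 500, (70,))],
    "features_pen": torch.tensor(0.01),
}

criterion = HubertPretrainLoss(
    pred_masked_weight=1.0, pred_nomask_weight=0.0, loss_weights=10.0
)
loss, logp_m_list, logp_u_list = criterion(DummyHubertModel(), enc_outputs)
print(float(loss))  # summed masked CE plus 10.0 * features_pen * sample_size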