Source code for espnet2.uasr.loss.pseudo_label_loss

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from espnet2.uasr.loss.abs_loss import AbsUASRLoss
from espnet2.utils.types import str2bool


class UASRPseudoLabelLoss(AbsUASRLoss):
    """Auxiliary pseudo label loss for UASR.

    A linear decoder maps intermediate features to ``output_dim`` logits,
    which are scored with cross-entropy against (optionally downsampled)
    pseudo label indices.

    Args:
        weight: Loss weight. When ``weight <= 0`` no decoder is created and
            :meth:`forward` always returns ``0``. The weight is not multiplied
            into the loss inside this module (presumably the caller does so).
        input_dim: Last-dimension size of the features passed to
            :meth:`forward` (decoder input size).
        output_dim: Number of pseudo label classes (decoder output size).
        downsample_rate: Keep every ``downsample_rate``-th pseudo label along
            the time axis before comparing with the features.
        ignore_index: Target value ignored by the cross-entropy.
        reduction: Reduction passed to ``F.cross_entropy``
            (``"none"``, ``"mean"`` or ``"sum"``).
    """

    def __init__(
        self,
        weight: float = 1.0,
        input_dim: int = 128,
        output_dim: int = 64,
        downsample_rate: int = 2,
        ignore_index: int = -1,
        reduction: str = "none",
    ):
        super().__init__()
        assert check_argument_types()

        self.weight = weight
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.downsample_rate = downsample_rate
        self.ignore_index = ignore_index
        self.reduction = reduction

        # The decoder is only needed (and only created) when the loss is active.
        if self.weight > 0:
            self.decoder = torch.nn.Linear(self.input_dim, self.output_dim)

    def forward(
        self,
        inter_x: torch.Tensor,
        pseudo_labels: torch.Tensor,
        is_discriminative_step: bool,
    ):
        """Forward.

        Args:
            inter_x: Intermediate features of shape (batch, time, input_dim).
            pseudo_labels: Pseudo label indices of shape (batch, label_time),
                or ``None`` when unavailable.
            is_discriminative_step: Whether the current update is a
                discriminator step; the loss is skipped in that case.

        Returns:
            The pseudo label loss as a scalar tensor, or ``0`` when the loss
            is disabled (``weight <= 0``), during a discriminative step, or
            when ``pseudo_labels`` is ``None``.
        """
        use_loss = (
            self.weight > 0
            and not is_discriminative_step
            and pseudo_labels is not None
        )
        if not use_loss:
            return 0

        logits = self.decoder(inter_x)
        if self.downsample_rate > 1:
            pseudo_labels = pseudo_labels[:, :: self.downsample_rate]

        # After downsampling the label and feature lengths may differ
        # slightly; only the overlapping frames are compared.
        valid_time_length = min(pseudo_labels.shape[1], logits.shape[1])
        pseudo_label_loss = F.cross_entropy(
            # cross_entropy expects the class axis at dim 1: (B, C, T).
            logits[:, :valid_time_length].transpose(1, 2),
            pseudo_labels[:, :valid_time_length],
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )

        if pseudo_label_loss.dim() > 0:
            # reduction="none": per-frame losses of shape (B, T). Average over
            # all frames (ignored frames contribute 0) and rescale by batch size.
            pseudo_label_loss = pseudo_label_loss.mean() * pseudo_label_loss.shape[0]
        # With reduction="mean"/"sum" the loss is already a 0-dim tensor;
        # indexing its shape would raise, so it is returned unchanged.
        return pseudo_label_loss