import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from espnet2.uasr.loss.abs_loss import AbsUASRLoss
class UASRSmoothnessPenalty(AbsUASRLoss):
    """Smoothness penalty for UASR.

    Penalizes the squared difference between neighbouring positions (taken
    along dim 1) of the generator's output logits. Pairs whose later position
    is marked by the padding mask contribute zero.
    """

    def __init__(
        self,
        weight: float = 1.0,
        reduction: str = "none",
    ):
        """Initialize the penalty.

        Args:
            weight: The penalty is only computed when ``weight > 0``. This
                module never multiplies its result by ``weight``; presumably
                the caller applies it (NOTE(review): confirm in the model).
            reduction: How the masked element-wise penalty is collapsed to a
                scalar before scaling by ``sample_size``: ``"sum"`` sums it;
                ``"none"`` (default) and ``"mean"`` average it over all
                elements, masked (zeroed) ones included.
        """
        super().__init__()
        assert check_argument_types()
        self.weight = weight
        self.reduction = reduction

    def forward(
        self,
        dense_logits: torch.Tensor,
        dense_padding_mask: torch.Tensor,
        sample_size: int,
        is_discriminative_step: bool,
    ):
        """Forward.

        Args:
            dense_logits: output logits of generator; neighbouring positions
                are compared along dim 1 (presumably (batch, time, vocab) --
                confirm).
            dense_padding_mask: padding mask of logits whose leading two dims
                match ``dense_logits``; positions where it is True are zeroed
                (assumes a bool mask).
            sample_size: batch size; the reduced penalty is multiplied by it.
            is_discriminative_step: Whether is training discriminator

        Returns:
            The scaled penalty as a scalar tensor, or ``0`` when the penalty is
            disabled (``weight`` not > 0) or on a discriminator step.

        Raises:
            ValueError: if ``self.reduction`` is not "none", "mean" or "sum".
        """
        # Written as ``not (weight > 0)`` so a NaN weight disables the penalty,
        # exactly like the original ``weight > 0`` check.
        if not (self.weight > 0) or is_discriminative_step:
            return 0

        if self.reduction not in ("none", "mean", "sum"):
            raise ValueError(
                f"{self.reduction} is not a valid value for reduction"
            )

        # The element-wise loss is required so padded positions can be zeroed.
        # Reducing inside mse_loss ("mean"/"sum") would yield a 0-dim tensor
        # that the mask below cannot index.
        smoothness_penalty = F.mse_loss(
            dense_logits[:, :-1], dense_logits[:, 1:], reduction="none"
        )
        # Mask by the *later* position of each neighbouring pair.
        smoothness_penalty[dense_padding_mask[:, 1:]] = 0

        if self.reduction == "sum":
            smoothness_penalty = smoothness_penalty.sum()
        else:
            # "none" and "mean" keep the original behaviour: plain mean.
            smoothness_penalty = smoothness_penalty.mean()
        return smoothness_penalty * sample_size