import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from espnet2.uasr.loss.abs_loss import AbsUASRLoss
from espnet2.utils.types import str2bool

class UASRDiscriminatorLoss(AbsUASRLoss):
    """Discriminator loss for UASR."""

    def __init__(
        self,
        weight: float = 1.0,
        smoothing: float = 0.0,
        smoothing_one_side: str2bool = False,
        reduction: str = "sum",
    ):
        super().__init__()
        assert check_argument_types()

        self.weight = weight
        self.smoothing = smoothing
        self.smoothing_one_sided = smoothing_one_side
        self.reduction = reduction

    def forward(
        self,
        dense_y: torch.Tensor,
        token_y: torch.Tensor,
        is_discriminative_step: str2bool,
    ):
        """Forward.

        Args:
            dense_y: predicted logits of generated samples
            token_y: predicted logits of real samples
            is_discriminative_step: True when updating the discriminator,
                False when updating the generator
        """
        if self.weight > 0:
            fake_smooth = self.smoothing
            real_smooth = self.smoothing
            if self.smoothing_one_sided:
                # One-sided smoothing softens only the real-sample targets.
                fake_smooth = 0

            if is_discriminative_step:
                # Discriminator update: generated (dense) logits are trained
                # against targets of 1 - fake_smooth, real (token) logits
                # against targets of 0 + real_smooth.
                loss_dense = F.binary_cross_entropy_with_logits(
                    dense_y,
                    dense_y.new_ones(dense_y.shape) - fake_smooth,
                    reduction=self.reduction,
                )
                loss_token = F.binary_cross_entropy_with_logits(
                    token_y,
                    token_y.new_zeros(token_y.shape) + real_smooth,
                    reduction=self.reduction,
                )
            else:
                # Generator update: only the loss on generated samples is needed.
                loss_dense = F.binary_cross_entropy_with_logits(
                    dense_y,
                    dense_y.new_zeros(dense_y.shape) + fake_smooth,
                    reduction=self.reduction,
                )
                loss_token = None
            return loss_dense, loss_token
        else:
            # Loss disabled: contribute nothing.
            return 0
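

# --- Usage sketch (illustrative, not part of the ESPnet source). The tensor
# shapes below are assumptions; any shape accepted by
# F.binary_cross_entropy_with_logits would work. It exercises both branches
# of forward().
if __name__ == "__main__":
    criterion = UASRDiscriminatorLoss(weight=1.0, smoothing=0.1)

    # Pretend discriminator logits for generated (dense) and real (token) batches.
    dense_logits = torch.randn(4, 50)
    token_logits = torch.randn(4, 50)

    # Discriminator step: both losses are returned.
    d_dense, d_token = criterion(dense_logits, token_logits, True)

    # Generator step: only the dense loss is computed; the token loss is None.
    g_dense, g_token = criterion(dense_logits, token_logits, False)
    print(d_dense.item(), d_token.item(), g_dense.item(), g_token)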