import torch
from typeguard import check_argument_types
from espnet2.uasr.loss.abs_loss import AbsUASRLoss
from espnet2.utils.types import str2bool
class UASRPhonemeDiversityLoss(AbsUASRLoss):
    """Phoneme diversity loss for UASR.

    Penalizes a low perplexity of the phoneme distribution averaged over all
    generated frames. The loss shrinks as the averaged distribution spreads
    over more phoneme classes.
    """

    def __init__(
        self,
        weight: float = 1.0,
    ):
        super().__init__()
        assert check_argument_types()
        # NOTE(review): within this class ``weight`` only gates whether the
        # loss is computed (weight > 0); it is never multiplied into the
        # returned value. Presumably the caller applies the scaling -- confirm.
        self.weight = weight

    def forward(
        self,
        dense_x: torch.Tensor,
        sample_size: int,
        is_discriminative_step: bool,
    ):
        """Forward.

        Args:
            dense_x: predicted logits of generated samples. The last
                dimension indexes phoneme classes; all leading dimensions
                (e.g. batch and time) are pooled together.
            sample_size: batch size; the normalized loss is scaled by it.
            is_discriminative_step: whether the discriminator is being trained

        Returns:
            ``(C - ppl) / C * sample_size`` as a tensor, where ``C`` is the
            number of phoneme classes and ``ppl`` is the perplexity
            (exponentiated entropy) of the frame-averaged softmax
            distribution. Returns ``0`` when ``weight`` is not positive or
            on a discriminator step.
        """
        if self.weight > 0 and not is_discriminative_step:
            # Only the class axis matters; flatten every leading dim so the
            # average is taken over all frames regardless of input rank.
            channel_size = dense_x.size(-1)
            # Computed in float32 so softmax/log stay stable under half precision.
            avg_probs = torch.softmax(
                dense_x.reshape(-1, channel_size).float(), dim=-1
            ).mean(dim=0)
            # exp(entropy); 1e-7 guards log(0) for classes never predicted.
            phoneme_ppl = torch.exp(
                -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
            )
            # Perplexity lies (approximately) in [1, C], so the normalized
            # term is near 0 when all classes are used uniformly.
            return ((channel_size - phoneme_ppl) / channel_size) * sample_size
        return 0