Source code for espnet2.gan_tts.jets.loss

# Copyright 2022 Dan Lim
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""JETS related loss module for ESPnet2."""

from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import betabinom
from typeguard import check_argument_types

from espnet.nets.pytorch_backend.fastspeech.duration_predictor import (  # noqa: H301
    DurationPredictorLoss,
)
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask


class VarianceLoss(torch.nn.Module):
    """Loss function module for the JETS variance predictors (duration, pitch, energy)."""

    def __init__(self, use_masking: bool = True, use_weighted_masking: bool = False):
        """Initialize JETS variance loss module.

        Args:
            use_masking (bool): Whether to apply masking for padded part in loss
                calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in loss
                calculation.

        """
        assert check_argument_types()
        super().__init__()

        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # define criterions
        reduction = "none" if self.use_weighted_masking else "mean"
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)
    def forward(
        self,
        d_outs: torch.Tensor,
        ds: torch.Tensor,
        p_outs: torch.Tensor,
        ps: torch.Tensor,
        e_outs: torch.Tensor,
        es: torch.Tensor,
        ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
            ds (LongTensor): Batch of durations (B, T_text).
            p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
            ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
            e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
            es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
            ilens (LongTensor): Batch of the lengths of each input (B,).

        Returns:
            Tensor: Duration predictor loss value.
            Tensor: Pitch predictor loss value.
            Tensor: Energy predictor loss value.

        """
        # apply mask to remove padded part
        if self.use_masking:
            duration_masks = make_non_pad_mask(ilens).to(ds.device)
            d_outs = d_outs.masked_select(duration_masks)
            ds = ds.masked_select(duration_masks)
            pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device)
            p_outs = p_outs.masked_select(pitch_masks)
            e_outs = e_outs.masked_select(pitch_masks)
            ps = ps.masked_select(pitch_masks)
            es = es.masked_select(pitch_masks)

        # calculate loss
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            duration_masks = make_non_pad_mask(ilens).to(ds.device)
            duration_weights = (
                duration_masks.float()
                / duration_masks.sum(dim=1, keepdim=True).float()
            )
            duration_weights /= ds.size(0)

            # apply weight
            duration_loss = (
                duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
            )
            pitch_masks = duration_masks.unsqueeze(-1)
            pitch_weights = duration_weights.unsqueeze(-1)
            pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
            energy_loss = (
                energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
            )

        return duration_loss, pitch_loss, energy_loss
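
# Illustrative usage sketch (not part of the original ESPnet source): the tensor
# shapes and length values below are arbitrary assumptions, chosen only to show
# how VarianceLoss expects predictor outputs, targets, and input lengths to be
# laid out.
def _example_variance_loss() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T_text = 2, 5  # hypothetical batch size and max text length
    criterion = VarianceLoss(use_masking=True)
    d_outs = torch.randn(B, T_text)  # predicted log durations
    ds = torch.randint(1, 4, (B, T_text))  # target durations
    p_outs = torch.randn(B, T_text, 1)  # predicted token-averaged pitch
    ps = torch.randn(B, T_text, 1)  # target token-averaged pitch
    e_outs = torch.randn(B, T_text, 1)  # predicted token-averaged energy
    es = torch.randn(B, T_text, 1)  # target token-averaged energy
    ilens = torch.tensor([5, 3])  # valid text lengths per utterance
    return criterion(d_outs, ds, p_outs, ps, e_outs, es, ilens)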
class ForwardSumLoss(torch.nn.Module):
    """Forwardsum loss described at https://openreview.net/forum?id=0NQwnnwAORi"""

    def __init__(self, cache_prior: bool = True):
        """Initialize forwardsum loss module.

        Args:
            cache_prior (bool): Whether to cache beta-binomial prior.

        """
        super().__init__()
        self.cache_prior = cache_prior
        self._cache = {}
    def forward(
        self,
        log_p_attn: torch.Tensor,
        ilens: torch.Tensor,
        olens: torch.Tensor,
        blank_prob: float = np.e**-1,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            log_p_attn (Tensor): Batch of log probability of attention matrix
                (B, T_feats, T_text).
            ilens (Tensor): Batch of the lengths of each input (B,).
            olens (Tensor): Batch of the lengths of each target (B,).
            blank_prob (float): Blank symbol probability.

        Returns:
            Tensor: forwardsum loss value.

        """
        B = log_p_attn.size(0)

        # add beta-binomial prior
        bb_prior = self._generate_prior(ilens, olens)
        bb_prior = bb_prior.to(dtype=log_p_attn.dtype, device=log_p_attn.device)
        log_p_attn = log_p_attn + bb_prior

        # a row must be added to the attention matrix to account for
        # blank token of CTC loss
        # (B, T_feats, T_text + 1)
        log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob))

        loss = 0
        for bidx in range(B):
            # construct target sequence.
            # Every text token is mapped to a unique sequence number.
            target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0)
            cur_log_p_attn_pd = log_p_attn_pd[
                bidx, : olens[bidx], : ilens[bidx] + 1
            ].unsqueeze(1)  # (T_feats, 1, T_text + 1)
            loss += F.ctc_loss(
                log_probs=cur_log_p_attn_pd,
                targets=target_seq,
                input_lengths=olens[bidx : bidx + 1],
                target_lengths=ilens[bidx : bidx + 1],
                zero_infinity=True,
            )
        loss = loss / B
        return loss
    def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor:
        """Generate alignment prior formulated as beta-binomial distribution.

        Args:
            text_lengths (Tensor): Batch of the lengths of each input (B,).
            feats_lengths (Tensor): Batch of the lengths of each target (B,).
            w (float): Scaling factor; lower -> wider the width.

        Returns:
            Tensor: Batched 2d static prior matrix (B, T_feats, T_text).

        """
        B = len(text_lengths)
        T_text = text_lengths.max()
        T_feats = feats_lengths.max()

        bb_prior = torch.full((B, T_feats, T_text), fill_value=-np.inf)
        for bidx in range(B):
            T = feats_lengths[bidx].item()
            N = text_lengths[bidx].item()

            key = str(T) + "," + str(N)
            if self.cache_prior and key in self._cache:
                prob = self._cache[key]
            else:
                alpha = w * np.arange(1, T + 1, dtype=float)  # (T,)
                beta = w * np.array([T - t + 1 for t in alpha])
                k = np.arange(N)
                batched_k = k[..., None]  # (N, 1)
                prob = betabinom.logpmf(batched_k, N, alpha, beta)  # (N, T)

            # store cache
            if self.cache_prior and key not in self._cache:
                self._cache[key] = prob

            prob = torch.from_numpy(prob).transpose(0, 1)  # -> (T, N)
            bb_prior[bidx, :T, :N] = prob

        return bb_prior
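
# Illustrative usage sketch (not part of the original ESPnet source): builds a
# random log-attention matrix just to show the expected (B, T_feats, T_text)
# layout and the per-utterance text/feature lengths; all sizes below are
# arbitrary assumptions. The beta-binomial prior is generated internally by
# _generate_prior and cached when cache_prior=True.
def _example_forward_sum_loss() -> torch.Tensor:
    B, T_feats, T_text = 2, 20, 6  # hypothetical batch, feature, and text sizes
    log_p_attn = torch.randn(B, T_feats, T_text).log_softmax(dim=-1)
    ilens = torch.tensor([6, 4])  # text lengths per utterance
    olens = torch.tensor([20, 15])  # feature lengths per utterance
    criterion = ForwardSumLoss(cache_prior=True)
    return criterion(log_p_attn, ilens, olens)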