# Source code for espnet2.gan_tts.vits.loss

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""VITS-related loss modules.

This code is based on https://github.com/jaywalnut310/vits.

"""

from typing import Optional

import torch
import torch.distributions as D


class KLDivergenceLoss(torch.nn.Module):
    """KL divergence loss."""

    def forward(
        self,
        z_p: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
        z_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate KL divergence loss.

        For every element the term
        ``logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p) ** 2 * exp(-2 * logs_p)``
        is computed. It is then zeroed on padded frames via ``z_mask``,
        summed over all elements, and divided by ``sum(z_mask)``, i.e. the
        number of valid frames (the hidden axis is summed, not averaged).

        Args:
            z_p (Tensor): Flow hidden representation (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Tensor): Mask tensor (B, 1, T_feats).

        Returns:
            Tensor: KL divergence loss (0-dim float32 tensor).

        """
        # Work in float32 regardless of the incoming dtype (e.g. fp16 inputs).
        z_p, logs_q, m_p, logs_p, z_mask = (
            tensor.float() for tensor in (z_p, logs_q, m_p, logs_p, z_mask)
        )

        inv_prior_var = torch.exp(-2.0 * logs_p)
        squared_dev = (z_p - m_p) ** 2
        elementwise = (logs_p - logs_q - 0.5) + 0.5 * squared_dev * inv_prior_var

        masked_total = (elementwise * z_mask).sum()
        return masked_total / z_mask.sum()
class KLDivergenceLossWithoutFlow(torch.nn.Module):
    """KL divergence loss without flow."""

    def forward(
        self,
        m_q: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
        z_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Calculate KL divergence loss without flow.

        Computes the closed-form ``KL(q || p)`` between the element-wise
        Gaussians ``q = Normal(m_q, exp(logs_q))`` and
        ``p = Normal(m_p, exp(logs_p))`` and averages it.

        Args:
            m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Optional[Tensor]): Mask tensor (B, 1, T_feats), bool or
                float. If given, padded frames are excluded and the result is
                the mean over the valid elements. If None (default), the mean
                is taken over every element, which matches the previous behavior.

        Returns:
            Tensor: KL divergence loss (0-dim float32 tensor).

        """
        # Cast to float32 like KLDivergenceLoss does. In fp16, exp(logs_*) can
        # underflow to 0, which Normal's positive-scale validation rejects.
        m_q = m_q.float()
        logs_q = logs_q.float()
        m_p = m_p.float()
        logs_p = logs_p.float()

        posterior_norm = D.Normal(m_q, torch.exp(logs_q))
        prior_norm = D.Normal(m_p, torch.exp(logs_p))
        kl = D.kl_divergence(posterior_norm, prior_norm)

        if z_mask is None:
            return kl.mean()

        # Broadcast the (B, 1, T) mask over the hidden axis so the denominator
        # counts elements, not frames. With an all-ones mask this equals kl.mean().
        mask = z_mask.float().expand_as(kl)
        return (kl * mask).sum() / mask.sum()