Source code for espnet2.gan_tts.vits.residual_coupling

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Residual affine coupling modules in VITS.

This code is based on https://github.com/jaywalnut310/vits.

"""

from typing import Optional, Tuple, Union

import torch

from espnet2.gan_tts.vits.flow import FlipFlow
from espnet2.gan_tts.wavenet import WaveNet


class ResidualAffineCouplingBlock(torch.nn.Module):
    """Residual affine coupling block module.

    This is a module of residual affine coupling block, which is used as "Flow"
    in `Conditional Variational Autoencoder with Adversarial Learning for
    End-to-End Text-to-Speech`_.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for
        End-to-End Text-to-Speech`: https://arxiv.org/abs/2006.04558

    """

    def __init__(
        self,
        in_channels: int = 192,
        hidden_channels: int = 192,
        flows: int = 4,
        kernel_size: int = 5,
        base_dilation: int = 1,
        layers: int = 4,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        use_weight_norm: bool = True,
        bias: bool = True,
        use_only_mean: bool = True,
    ):
        """Initialize ResidualAffineCouplingBlock module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            flows (int): Number of flows.
            kernel_size (int): Kernel size for WaveNet.
            base_dilation (int): Base dilation factor for WaveNet.
            layers (int): Number of layers of WaveNet.
            global_channels (int): Number of global channels.
            dropout_rate (float): Dropout rate.
            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
            use_only_mean (bool): Whether to estimate only the mean.

        """
        super().__init__()

        self.flows = torch.nn.ModuleList()
        for i in range(flows):
            self.flows += [
                ResidualAffineCouplingLayer(
                    in_channels=in_channels,
                    hidden_channels=hidden_channels,
                    kernel_size=kernel_size,
                    base_dilation=base_dilation,
                    layers=layers,
                    stacks=1,
                    global_channels=global_channels,
                    dropout_rate=dropout_rate,
                    use_weight_norm=use_weight_norm,
                    bias=bias,
                    use_only_mean=use_only_mean,
                )
            ]
            self.flows += [FlipFlow()]
    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor
                (B, global_channels, 1).
            inverse (bool): Whether to invert the flow.

        Returns:
            Tensor: Output tensor (B, in_channels, T).

        """
        if not inverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, inverse=inverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, inverse=inverse)
        return x
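
# A minimal usage sketch (added for illustration, not part of the original
# source; the shapes below are assumptions). The block maps a (B, in_channels,
# T) tensor to the same shape, and calling it again with ``inverse=True``
# undoes the forward pass exactly, since every sub-flow is invertible:
#
#     >>> flow = ResidualAffineCouplingBlock(in_channels=192)
#     >>> x = torch.randn(2, 192, 50)
#     >>> x_mask = torch.ones(2, 1, 50)
#     >>> z = flow(x, x_mask)                    # forward: x -> z
#     >>> x_hat = flow(z, x_mask, inverse=True)  # inverse: z -> x
#     >>> torch.allclose(x, x_hat, atol=1e-4)
#     True
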
class ResidualAffineCouplingLayer(torch.nn.Module):
    """Residual affine coupling layer."""

    def __init__(
        self,
        in_channels: int = 192,
        hidden_channels: int = 192,
        kernel_size: int = 5,
        base_dilation: int = 1,
        layers: int = 5,
        stacks: int = 1,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        use_weight_norm: bool = True,
        bias: bool = True,
        use_only_mean: bool = True,
    ):
        """Initialize ResidualAffineCouplingLayer module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size for WaveNet.
            base_dilation (int): Base dilation factor for WaveNet.
            layers (int): Number of layers of WaveNet.
            stacks (int): Number of stacks of WaveNet.
            global_channels (int): Number of global channels.
            dropout_rate (float): Dropout rate.
            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
            use_only_mean (bool): Whether to estimate only the mean.

        """
        assert in_channels % 2 == 0, "in_channels should be divisible by 2"
        super().__init__()
        self.half_channels = in_channels // 2
        self.use_only_mean = use_only_mean

        # define modules
        self.input_conv = torch.nn.Conv1d(
            self.half_channels,
            hidden_channels,
            1,
        )
        self.encoder = WaveNet(
            in_channels=-1,
            out_channels=-1,
            kernel_size=kernel_size,
            layers=layers,
            stacks=stacks,
            base_dilation=base_dilation,
            residual_channels=hidden_channels,
            aux_channels=-1,
            gate_channels=hidden_channels * 2,
            skip_channels=hidden_channels,
            global_channels=global_channels,
            dropout_rate=dropout_rate,
            bias=bias,
            use_weight_norm=use_weight_norm,
            use_first_conv=False,
            use_last_conv=False,
            scale_residual=False,
            scale_skip_connect=True,
        )
        if use_only_mean:
            self.proj = torch.nn.Conv1d(
                hidden_channels,
                self.half_channels,
                1,
            )
        else:
            self.proj = torch.nn.Conv1d(
                hidden_channels,
                self.half_channels * 2,
                1,
            )
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()
    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor
                (B, global_channels, 1).
            inverse (bool): Whether to invert the flow.

        Returns:
            Tensor: Output tensor (B, in_channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        xa, xb = x.split(x.size(1) // 2, dim=1)
        h = self.input_conv(xa) * x_mask
        h = self.encoder(h, x_mask, g=g)
        stats = self.proj(h) * x_mask
        if not self.use_only_mean:
            m, logs = stats.split(stats.size(1) // 2, dim=1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not inverse:
            xb = m + xb * torch.exp(logs) * x_mask
            x = torch.cat([xa, xb], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            xb = (xb - m) * torch.exp(-logs) * x_mask
            x = torch.cat([xa, xb], 1)
            return x
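
if __name__ == "__main__":
    # A minimal invertibility smoke test (added for illustration, not part of
    # the original source; shapes and tolerances are assumptions). Each
    # coupling layer leaves the first half of the channels untouched and
    # applies an affine transform to the second half conditioned on the first,
    # so the inverse is exact in closed form.
    layer = ResidualAffineCouplingLayer(in_channels=192, use_only_mean=False)
    x = torch.randn(2, 192, 50)
    x_mask = torch.ones(2, 1, 50)

    # The forward pass returns the transformed tensor and the per-sample
    # log-determinant of the Jacobian. Because self.proj is zero-initialized,
    # a freshly constructed layer is the identity map and logdet starts at
    # zero, which keeps early training stable.
    z, logdet = layer(x, x_mask)
    assert torch.allclose(logdet, torch.zeros_like(logdet))

    # The inverse pass recovers the input up to floating-point precision.
    x_hat = layer(z, x_mask, inverse=True)
    assert torch.allclose(x, x_hat, atol=1e-4)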