Source code for espnet2.gan_svs.vits.generator

# Copyright 2021 Tomoki Hayashi
# Copyright 2022 Yifeng Yu
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Generator module in VISinger.

This code is based on https://github.com/jaywalnut310/vits.

    This is a module of VISinger described in `VISinger: Variational Inference
      with Adversarial Learning for End-to-End Singing Voice Synthesis`_.

    .. _`VISinger: Variational Inference with Adversarial Learning for
      End-to-End Singing Voice Synthesis`: https://arxiv.org/abs/2110.08813

"""

import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from espnet2.gan_svs.avocodo import AvocodoGenerator
from espnet2.gan_svs.uhifigan import UHiFiGANGenerator
from espnet2.gan_svs.uhifigan.sine_generator import SineGen
from espnet2.gan_svs.utils.expand_f0 import expand_f0
from espnet2.gan_svs.visinger2 import (
    Generator_Harm,
    Generator_Noise,
    VISinger2VocoderGenerator,
)
from espnet2.gan_svs.visinger2.ddsp import upsample
from espnet2.gan_svs.vits.duration_predictor import DurationPredictor
from espnet2.gan_svs.vits.length_regulator import LengthRegulator
from espnet2.gan_svs.vits.phoneme_predictor import PhonemePredictor
from espnet2.gan_svs.vits.pitch_predictor import Decoder
from espnet2.gan_svs.vits.prior_decoder import PriorDecoder
from espnet2.gan_svs.vits.text_encoder import TextEncoder
from espnet2.gan_tts.hifigan import HiFiGANGenerator
from espnet2.gan_tts.utils import get_random_segments, get_segments
from espnet2.gan_tts.vits.posterior_encoder import PosteriorEncoder
from espnet2.gan_tts.vits.residual_coupling import ResidualAffineCouplingBlock


[docs]class VISingerGenerator(torch.nn.Module): """Generator module in VISinger.""" def __init__( self, vocabs: int, aux_channels: int = 513, hidden_channels: int = 192, spks: Optional[int] = None, langs: Optional[int] = None, spk_embed_dim: Optional[int] = None, global_channels: int = -1, segment_size: int = 32, text_encoder_attention_heads: int = 2, text_encoder_ffn_expand: int = 4, text_encoder_blocks: int = 6, text_encoder_positionwise_layer_type: str = "conv1d", text_encoder_positionwise_conv_kernel_size: int = 1, text_encoder_positional_encoding_layer_type: str = "rel_pos", text_encoder_self_attention_layer_type: str = "rel_selfattn", text_encoder_activation_type: str = "swish", text_encoder_normalize_before: bool = True, text_encoder_dropout_rate: float = 0.1, text_encoder_positional_dropout_rate: float = 0.0, text_encoder_attention_dropout_rate: float = 0.0, text_encoder_conformer_kernel_size: int = 7, use_macaron_style_in_text_encoder: bool = True, use_conformer_conv_in_text_encoder: bool = True, decoder_kernel_size: int = 7, decoder_channels: int = 512, decoder_downsample_scales: List[int] = [2, 2, 8, 8], decoder_downsample_kernel_sizes: List[int] = [4, 4, 16, 16], decoder_upsample_scales: List[int] = [8, 8, 2, 2], decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4], decoder_resblock_kernel_sizes: List[int] = [3, 7, 11], decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # avocodo use_avocodo=False, projection_filters: List[int] = [0, 1, 1, 1], projection_kernels: List[int] = [0, 5, 7, 11], # visinger 2 n_harmonic: int = 64, use_weight_norm_in_decoder: bool = True, posterior_encoder_kernel_size: int = 5, posterior_encoder_layers: int = 16, posterior_encoder_stacks: int = 1, posterior_encoder_base_dilation: int = 1, posterior_encoder_dropout_rate: float = 0.0, use_weight_norm_in_posterior_encoder: bool = True, flow_flows: int = 4, flow_kernel_size: int = 5, flow_base_dilation: int = 1, flow_layers: int = 4, flow_dropout_rate: float = 0.0, use_weight_norm_in_flow: bool = True, use_only_mean_in_flow: bool = True, generator_type: str = "visinger", vocoder_generator_type: str = "hifigan", fs: int = 22050, hop_length: int = 256, win_length: int = 1024, n_fft: int = 1024, use_phoneme_predictor: bool = False, expand_f0_method: str = "repeat", ): """Initialize VITS generator module. Args: vocabs (int): Input vocabulary size. aux_channels (int): Number of acoustic feature channels. hidden_channels (int): Number of hidden channels. spks (Optional[int]): Number of speakers. If set to > 1, assume that the sids will be provided as the input and use sid embedding layer. langs (Optional[int]): Number of languages. If set to > 1, assume that the lids will be provided as the input and use sid embedding layer. spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, assume that spembs will be provided as the input. global_channels (int): Number of global conditioning channels. segment_size (int): Segment size for decoder. text_encoder_attention_heads (int): Number of heads in conformer block of text encoder. text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block of text encoder. text_encoder_blocks (int): Number of conformer blocks in text encoder. text_encoder_positionwise_layer_type (str): Position-wise layer type in conformer block of text encoder. text_encoder_positionwise_conv_kernel_size (int): Position-wise convolution kernel size in conformer block of text encoder. Only used when the above layer type is conv1d or conv1d-linear. text_encoder_positional_encoding_layer_type (str): Positional encoding layer type in conformer block of text encoder. text_encoder_self_attention_layer_type (str): Self-attention layer type in conformer block of text encoder. text_encoder_activation_type (str): Activation function type in conformer block of text encoder. text_encoder_normalize_before (bool): Whether to apply layer norm before self-attention in conformer block of text encoder. text_encoder_dropout_rate (float): Dropout rate in conformer block of text encoder. text_encoder_positional_dropout_rate (float): Dropout rate for positional encoding in conformer block of text encoder. text_encoder_attention_dropout_rate (float): Dropout rate for attention in conformer block of text encoder. text_encoder_conformer_kernel_size (int): Conformer conv kernel size. It will be used when only use_conformer_conv_in_text_encoder = True. use_macaron_style_in_text_encoder (bool): Whether to use macaron style FFN in conformer block of text encoder. use_conformer_conv_in_text_encoder (bool): Whether to use covolution in conformer block of text encoder. decoder_kernel_size (int): Decoder kernel size. decoder_channels (int): Number of decoder initial channels. decoder_downsample_scales (List[int]): List of downsampling scales in decoder. decoder_downsample_kernel_sizes (List[int]): List of kernel sizes for downsampling layers in decoder. decoder_upsample_scales (List[int]): List of upsampling scales in decoder. decoder_upsample_kernel_sizes (List[int]): List of kernel sizes for upsampling layers in decoder. decoder_resblock_kernel_sizes (List[int]): List of kernel sizes for resblocks in decoder. decoder_resblock_dilations (List[List[int]]): List of list of dilations for resblocks in decoder. use_avocodo (bool): Whether to use Avocodo model in the generator. projection_filters (List[int]): List of projection filter sizes. projection_kernels (List[int]): List of projection kernel sizes. n_harmonic (int): Number of harmonic components. use_weight_norm_in_decoder (bool): Whether to apply weight normalization in decoder. posterior_encoder_kernel_size (int): Posterior encoder kernel size. posterior_encoder_layers (int): Number of layers of posterior encoder. posterior_encoder_stacks (int): Number of stacks of posterior encoder. posterior_encoder_base_dilation (int): Base dilation of posterior encoder. posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder. use_weight_norm_in_posterior_encoder (bool): Whether to apply weight normalization in posterior encoder. flow_flows (int): Number of flows in flow. flow_kernel_size (int): Kernel size in flow. flow_base_dilation (int): Base dilation in flow. flow_layers (int): Number of layers in flow. flow_dropout_rate (float): Dropout rate in flow use_weight_norm_in_flow (bool): Whether to apply weight normalization in flow. use_only_mean_in_flow (bool): Whether to use only mean in flow. generator_type (str): Type of generator to use for the model. vocoder_generator_type (str): Type of vocoder generator to use for the model. fs (int): Sample rate of the audio. hop_length (int): Number of samples between successive frames in STFT. win_length (int): Window size of the STFT. n_fft (int): Length of the FFT window to be used. use_phoneme_predictor (bool): Whether to use phoneme predictor in the model. expand_f0_method (str): The method used to expand F0. Use "repeat" or "interpolation". """ super().__init__() self.aux_channels = aux_channels self.hidden_channels = hidden_channels self.generator_type = generator_type self.segment_size = segment_size self.sample_rate = fs self.hop_length = hop_length self.use_avocodo = use_avocodo self.use_flow = True if flow_flows > 0 else False self.use_phoneme_predictor = use_phoneme_predictor self.text_encoder = TextEncoder( vocabs=vocabs, attention_dim=hidden_channels, attention_heads=text_encoder_attention_heads, linear_units=hidden_channels * text_encoder_ffn_expand, blocks=text_encoder_blocks, positionwise_layer_type=text_encoder_positionwise_layer_type, positionwise_conv_kernel_size=text_encoder_positionwise_conv_kernel_size, positional_encoding_layer_type=text_encoder_positional_encoding_layer_type, self_attention_layer_type=text_encoder_self_attention_layer_type, activation_type=text_encoder_activation_type, normalize_before=text_encoder_normalize_before, dropout_rate=text_encoder_dropout_rate, positional_dropout_rate=text_encoder_positional_dropout_rate, attention_dropout_rate=text_encoder_attention_dropout_rate, conformer_kernel_size=text_encoder_conformer_kernel_size, use_macaron_style=use_macaron_style_in_text_encoder, use_conformer_conv=use_conformer_conv_in_text_encoder, ) if vocoder_generator_type == "uhifigan": self.decoder = UHiFiGANGenerator( in_channels=hidden_channels, out_channels=1, channels=decoder_channels, global_channels=global_channels, kernel_size=decoder_kernel_size, downsample_scales=decoder_downsample_scales, downsample_kernel_sizes=decoder_downsample_kernel_sizes, upsample_scales=decoder_upsample_scales, upsample_kernel_sizes=decoder_upsample_kernel_sizes, resblock_kernel_sizes=decoder_resblock_kernel_sizes, resblock_dilations=decoder_resblock_dilations, use_weight_norm=use_weight_norm_in_decoder, use_avocodo=use_avocodo, ) self.sine_generator = SineGen( sample_rate=fs, ) elif vocoder_generator_type == "hifigan": self.decoder = HiFiGANGenerator( in_channels=hidden_channels, out_channels=1, channels=decoder_channels, global_channels=global_channels, kernel_size=decoder_kernel_size, upsample_scales=decoder_upsample_scales, upsample_kernel_sizes=decoder_upsample_kernel_sizes, resblock_kernel_sizes=decoder_resblock_kernel_sizes, resblock_dilations=decoder_resblock_dilations, use_weight_norm=use_weight_norm_in_decoder, ) elif vocoder_generator_type == "avocodo": self.decoder = AvocodoGenerator( in_channels=hidden_channels, out_channels=1, channels=decoder_channels, global_channels=global_channels, kernel_size=decoder_kernel_size, upsample_scales=decoder_upsample_scales, upsample_kernel_sizes=decoder_upsample_kernel_sizes, resblock_kernel_sizes=decoder_resblock_kernel_sizes, resblock_dilations=decoder_resblock_dilations, projection_filters=projection_filters, projection_kernels=projection_kernels, use_weight_norm=use_weight_norm_in_decoder, ) elif vocoder_generator_type == "visinger2": self.decoder = VISinger2VocoderGenerator( in_channels=hidden_channels, out_channels=1, channels=decoder_channels, global_channels=global_channels, kernel_size=decoder_kernel_size, upsample_scales=decoder_upsample_scales, upsample_kernel_sizes=decoder_upsample_kernel_sizes, resblock_kernel_sizes=decoder_resblock_kernel_sizes, resblock_dilations=decoder_resblock_dilations, use_weight_norm=use_weight_norm_in_decoder, n_harmonic=n_harmonic, ) self.dec_harm = Generator_Harm( hidden_channels=hidden_channels, n_harmonic=n_harmonic, kernel_size=3, padding=1, dropout_rate=0.1, sample_rate=fs, hop_size=hop_length, ) self.dec_noise = Generator_Noise( win_length=win_length, hop_length=hop_length, n_fft=n_fft, hidden_channels=hidden_channels, kernel_size=3, padding=1, dropout_rate=0.1, ) self.sin_prenet = torch.nn.Conv1d(1, n_harmonic + 2, 3, padding=1) else: raise ValueError( f"Not supported vocoder generator type: {vocoder_generator_type}" ) self.posterior_encoder = PosteriorEncoder( in_channels=aux_channels, out_channels=hidden_channels, hidden_channels=hidden_channels, kernel_size=posterior_encoder_kernel_size, layers=posterior_encoder_layers, stacks=posterior_encoder_stacks, base_dilation=posterior_encoder_base_dilation, global_channels=global_channels, dropout_rate=posterior_encoder_dropout_rate, use_weight_norm=use_weight_norm_in_posterior_encoder, ) if self.use_flow: self.flow = ResidualAffineCouplingBlock( in_channels=hidden_channels, hidden_channels=hidden_channels, flows=flow_flows, kernel_size=flow_kernel_size, base_dilation=flow_base_dilation, layers=flow_layers, global_channels=global_channels, dropout_rate=flow_dropout_rate, use_weight_norm=use_weight_norm_in_flow, use_only_mean=use_only_mean_in_flow, ) self.f0_prenet = torch.nn.Conv1d(1, hidden_channels + 2, 3, padding=1) if generator_type == "visinger2": self.energy_prenet = torch.nn.Conv1d(1, hidden_channels + 2, 3, padding=1) self.mel_prenet = torch.nn.Conv1d( aux_channels, hidden_channels + 2, 3, padding=1 ) # TODO(kan-bayashi): Add deterministic version as an option self.duration_predictor = DurationPredictor( channels=hidden_channels, filter_channels=256, kernel_size=3, dropout_rate=0.5, global_channels=global_channels, ) self.lr = LengthRegulator() if self.use_phoneme_predictor: self.phoneme_predictor = PhonemePredictor( vocabs=vocabs, hidden_channels=hidden_channels, attention_dim=hidden_channels, blocks=2, ) self.f0_decoder = Decoder( 1, attention_dim=hidden_channels, attention_heads=text_encoder_attention_heads, linear_units=hidden_channels * text_encoder_ffn_expand, blocks=text_encoder_blocks, pw_layer_type=text_encoder_positionwise_layer_type, pw_conv_kernel_size=text_encoder_positionwise_conv_kernel_size, pos_enc_layer_type=text_encoder_positional_encoding_layer_type, self_attention_layer_type=text_encoder_self_attention_layer_type, activation_type=text_encoder_activation_type, normalize_before=text_encoder_normalize_before, dropout_rate=text_encoder_dropout_rate, positional_dropout_rate=text_encoder_positional_dropout_rate, attention_dropout_rate=text_encoder_attention_dropout_rate, conformer_kernel_size=text_encoder_conformer_kernel_size, use_macaron_style=use_macaron_style_in_text_encoder, use_conformer_conv=use_conformer_conv_in_text_encoder, global_channels=global_channels, ) if self.generator_type == "visinger2": self.mel_decoder = Decoder( out_channels=aux_channels, attention_dim=hidden_channels, attention_heads=text_encoder_attention_heads, linear_units=hidden_channels * text_encoder_ffn_expand, blocks=text_encoder_blocks, pw_layer_type=text_encoder_positionwise_layer_type, pw_conv_kernel_size=text_encoder_positionwise_conv_kernel_size, pos_enc_layer_type=text_encoder_positional_encoding_layer_type, self_attention_layer_type=text_encoder_self_attention_layer_type, activation_type=text_encoder_activation_type, normalize_before=text_encoder_normalize_before, dropout_rate=text_encoder_dropout_rate, positional_dropout_rate=text_encoder_positional_dropout_rate, attention_dropout_rate=text_encoder_attention_dropout_rate, conformer_kernel_size=text_encoder_conformer_kernel_size, use_macaron_style=use_macaron_style_in_text_encoder, use_conformer_conv=use_conformer_conv_in_text_encoder, global_channels=global_channels, ) self.prior_decoder = PriorDecoder( out_channels=hidden_channels * 2, attention_dim=hidden_channels, attention_heads=text_encoder_attention_heads, linear_units=hidden_channels * text_encoder_ffn_expand, blocks=text_encoder_blocks, positionwise_layer_type=text_encoder_positionwise_layer_type, positionwise_conv_kernel_size=text_encoder_positionwise_conv_kernel_size, positional_encoding_layer_type=text_encoder_positional_encoding_layer_type, self_attention_layer_type=text_encoder_self_attention_layer_type, activation_type=text_encoder_activation_type, normalize_before=text_encoder_normalize_before, dropout_rate=text_encoder_dropout_rate, positional_dropout_rate=text_encoder_positional_dropout_rate, attention_dropout_rate=text_encoder_attention_dropout_rate, conformer_kernel_size=text_encoder_conformer_kernel_size, use_macaron_style=use_macaron_style_in_text_encoder, use_conformer_conv=use_conformer_conv_in_text_encoder, global_channels=global_channels, ) self.upsample_factor = int(np.prod(decoder_upsample_scales)) self.spks = None if spks is not None and spks > 1: assert global_channels > 0 self.spks = spks self.global_emb = torch.nn.Embedding(spks, global_channels) self.spk_embed_dim = None if spk_embed_dim is not None and spk_embed_dim > 0: assert global_channels > 0 self.spk_embed_dim = spk_embed_dim self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels) self.langs = None if langs is not None and langs > 1: assert global_channels > 0 self.langs = langs self.lang_emb = torch.nn.Embedding(langs, global_channels) self.vocoder_generator_type = vocoder_generator_type self.dropout = torch.nn.Dropout(0.2) self.expand_f0_method = expand_f0_method
[docs] def forward( self, text: torch.Tensor, text_lengths: torch.Tensor, feats: torch.Tensor, feats_lengths: torch.Tensor, label: torch.Tensor = None, label_lengths: torch.Tensor = None, melody: torch.Tensor = None, gt_dur: torch.Tensor = None, score_dur: torch.Tensor = None, slur: torch.Tensor = None, pitch: torch.Tensor = None, ying: Optional[torch.Tensor] = None, sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, ) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, ], ]: """Calculate forward propagation. Args: text (LongTensor): Batch of padded character ids (B, Tmax). text_lengths (LongTensor): Batch of lengths of each input batch (B,). feats (Tensor): Batch of padded target features (B, Lmax, odim). feats_lengths (LongTensor): Batch of the lengths of each target (B,). label (LongTensor): Batch of padded label ids (B, Tmax). label_lengths (LongTensor): Batch of the lengths of padded label ids (B, ). melody (LongTensor): Batch of padded midi (B, Tmax). gt_dur (LongTensor): Batch of padded ground truth duration (B, Tmax). score_dur (LongTensor): Batch of padded score duration (B, Tmax). pitch (FloatTensor): Batch of padded f0 (B, Tmax). ying (Optional[Tensor]): Batch of padded ying (B, Tmax). spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim). sids (Optional[Tensor]): Batch of speaker IDs (B, 1). lids (Optional[Tensor]): Batch of language IDs (B, 1). Returns: Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). Tensor: Duration negative log-likelihood (NLL) tensor (B,). Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text). Tensor: Segments start index tensor (B,). Tensor: Text mask tensor (B, 1, T_text). Tensor: Feature mask tensor (B, 1, T_feats). tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - Tensor: Posterior encoder hidden representation (B, H, T_feats). - Tensor: Flow hidden representation (B, H, T_feats). - Tensor: Expanded text encoder projected mean (B, H, T_feats). - Tensor: Expanded text encoder projected scale (B, H, T_feats). - Tensor: Posterior encoder projected mean (B, H, T_feats). - Tensor: Posterior encoder projected scale (B, H, T_feats). """ # calculate global conditioning g = None if self.spks is not None: # speaker one-hot vector embedding: (B, global_channels, 1) g = self.global_emb(sids.view(-1)).unsqueeze(-1) if self.spk_embed_dim is not None: # pretreined speaker embedding, e.g., X-vector (B, global_channels, 1) g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) if g is None: g = g_ else: g = g + g_ if self.langs is not None: # language one-hot vector embedding: (B, global_channels, 1) g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) if g is None: g = g_ else: g = g + g_ # forward text encoder # Encoder x, x_mask, dur_input, x_pitch = self.text_encoder( label, label_lengths, melody, score_dur, slur ) # dur # Note this is different, we use frame level duration not time level # but it has no big difference on performance predict_dur = self.duration_predictor(dur_input, x_mask, g=g) predict_dur = (torch.exp(predict_dur) - 1) * x_mask predict_dur = predict_dur * self.sample_rate / self.hop_length # LR decoder_input, mel_len = self.lr(x, gt_dur, use_state_info=True) decoder_input_pitch, mel_len = self.lr(x_pitch, gt_dur, use_state_info=True) LF0 = 2595.0 * torch.log10(1.0 + pitch / 700.0) LF0 = LF0 / 500 LF0 = LF0.transpose(1, 2) predict_lf0, predict_bn_mask = self.f0_decoder( decoder_input + decoder_input_pitch, feats_lengths, g=g ) predict_lf0 = torch.max( predict_lf0, torch.zeros_like(predict_lf0).to(predict_lf0) ) if self.generator_type == "visinger2": predict_mel, predict_bn_mask = self.mel_decoder( decoder_input + self.f0_prenet(LF0), feats_lengths, g=g ) predict_energy = ( predict_mel.detach().sum(1).unsqueeze(1) / self.aux_channels ) decoder_input = decoder_input + self.f0_prenet(LF0) if self.generator_type == "visinger2": decoder_input = ( decoder_input + self.energy_prenet(predict_energy) + self.mel_prenet(predict_mel.detach()) ) decoder_output, predict_bn_mask = self.prior_decoder( decoder_input, feats_lengths, g=g ) prior_info = decoder_output prior_mean = prior_info[:, : self.hidden_channels, :] prior_logstd = prior_info[:, self.hidden_channels :, :] # forward posterior encoder posterior_z, posterior_mean, posterior_logstd, y_mask = self.posterior_encoder( feats, feats_lengths, g=g ) if self.use_flow: z_flow = self.flow(posterior_z, y_mask, g=g) else: z_flow = None # phoneme predictor if self.use_phoneme_predictor: log_probs = self.phoneme_predictor(posterior_z, y_mask) else: log_probs = None p_z = posterior_z p_z = self.dropout(p_z) # get random segments z_segments, z_start_idxs = get_random_segments( p_z, feats_lengths, self.segment_size ) if self.vocoder_generator_type == "uhifigan": # get sine wave # def plot_sine_waves(sine_waves, name): # import matplotlib.pyplot as plt # sine_waves_np = sine_waves[0].detach().cpu().numpy() # plt.plot(sine_waves_np) # plt.xlabel("Time (samples)") # plt.ylabel("Amplitude") # plt.title("Sine Wave") # plt.savefig(name + ".png") # plt.close() # plot_sine_waves(pitch_segments[0], "pitch_segments") pitch_segments = get_segments( pitch.transpose(1, 2), z_start_idxs, self.segment_size ) pitch_segments_expended = expand_f0( pitch_segments, self.hop_length, method="repeat" ) # plot_sine_waves( # pitch_segments_expended[0].unsqueeze(0), "pitch_segments_expended" # ) pitch_segments_expended = pitch_segments_expended.reshape( -1, pitch_segments_expended.shape[-1], 1 ) sine_waves, uv, noise = self.sine_generator(pitch_segments_expended) sine_waves = sine_waves.transpose(1, 2) wav = self.decoder(z_segments, excitation=sine_waves, g=g) elif self.vocoder_generator_type == "visinger2": pitch_ = upsample(pitch, self.hop_length) omega = torch.cumsum(2 * math.pi * pitch_ / self.sample_rate, 1) sin = torch.sin(omega).transpose(1, 2) # dsp synthesize pitch = pitch.transpose(1, 2) noise_x = self.dec_noise(posterior_z, y_mask) harm_x = self.dec_harm(pitch, posterior_z, y_mask) # dsp waveform dsp_o = torch.cat([harm_x, noise_x], axis=1) # decoder_condition = torch.cat([harm_x, noise_x, sin], axis=1) decoder_condition = self.sin_prenet(sin) # dsp based HiFiGAN vocoder F0_slice = get_segments(pitch, z_start_idxs, self.segment_size) dsp_slice = get_segments( dsp_o, z_start_idxs * self.hop_length, self.segment_size * self.hop_length, ) condition_slice = get_segments( decoder_condition, z_start_idxs * self.hop_length, self.segment_size * self.hop_length, ) wav = self.decoder(z_segments, condition_slice, g=g) else: wav = self.decoder(z_segments, g=g) # wav = dsp_slice.sum(1, keepdim=True) common_tuple = ( posterior_z, z_flow, prior_mean, prior_logstd, posterior_mean, posterior_logstd, predict_lf0, LF0 * predict_bn_mask, predict_dur, gt_dur, log_probs, ) output = (wav, z_start_idxs, x_mask, y_mask, common_tuple) if self.vocoder_generator_type == "visinger2": output = output + (dsp_slice.sum(1),) if self.generator_type == "visinger2": output = output + (predict_mel,) return output
[docs] def inference( self, text: torch.Tensor, text_lengths: torch.Tensor, feats: Optional[torch.Tensor] = None, feats_lengths: Optional[torch.Tensor] = None, label: torch.Tensor = None, label_lengths: torch.Tensor = None, melody: torch.Tensor = None, score_dur: torch.Tensor = None, slur: torch.Tensor = None, gt_dur: Optional[torch.Tensor] = None, pitch: Optional[torch.Tensor] = None, sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, noise_scale: float = 0.667, noise_scale_dur: float = 0.8, alpha: float = 1.0, max_len: Optional[int] = None, use_teacher_forcing: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Run inference. Args: text (LongTensor): Batch of padded character ids (B, Tmax). text_lengths (LongTensor): Batch of lengths of each input batch (B,). feats (Tensor): Batch of padded target features (B, Lmax, odim). feats_lengths (LongTensor): Batch of the lengths of each target (B,). label (LongTensor): Batch of padded label ids (B, Tmax). label_lengths (LongTensor): Batch of the lengths of padded label ids (B, ). melody (LongTensor): Batch of padded midi (B, Tmax). gt_dur (LongTensor): Batch of padded ground truth duration (B, Tmax). score_dur (LongTensor): Batch of padded score duration (B, Tmax). pitch (FloatTensor): Batch of padded f0 (B, Tmax). ying (Optional[Tensor]): Batch of padded ying (B, Tmax). spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim). sids (Optional[Tensor]): Batch of speaker IDs (B, 1). lids (Optional[Tensor]): Batch of language IDs (B, 1). noise_scale (float): Noise scale parameter for flow. noise_scale_dur (float): Noise scale parameter for duration predictor. alpha (float): Alpha parameter to control the speed of generated speech. max_len (Optional[int]): Maximum length of acoustic feature sequence. use_teacher_forcing (bool): Whether to use teacher forcing. Returns: Tensor: Generated waveform tensor (B, T_wav). """ # encoder x, x_mask, dur_input, x_pitch = self.text_encoder( label, label_lengths, melody, score_dur, slur ) g = None if self.spks is not None: # (B, global_channels, 1) g = self.global_emb(sids.view(-1)).unsqueeze(-1) if self.spk_embed_dim is not None: # (B, global_channels, 1) g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) if g is None: g = g_ else: g = g + g_ if self.langs is not None: # (B, global_channels, 1) g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) if g is None: g = g_ else: g = g + g_ if use_teacher_forcing: # forward posterior encoder z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) # forward flow if self.use_flow: z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) # decoder pitch = pitch.transpose(0, 1).reshape(1, 1, -1) if self.vocoder_generator_type == "uhifigan": pitch_segments_expended = expand_f0( pitch, self.hop_length, method=self.expand_f0_method ) pitch_segments_expended = pitch_segments_expended.reshape( -1, pitch_segments_expended.shape[-1], 1 ) sine_waves, uv, noise = self.sine_generator(pitch_segments_expended) sine_waves = sine_waves.transpose(1, 2) wav = self.decoder( (z * y_mask)[:, :, :max_len], excitation=sine_waves, g=g ) elif self.vocoder_generator_type == "avocodo": wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)[-1] elif self.vocoder_generator_type == "visinger2": pitch_ = upsample(pitch.transpose(1, 2), self.hop_length) omega = torch.cumsum(2 * math.pi * pitch_ / self.sample_rate, 1) sin = torch.sin(omega).transpose(1, 2) # dsp synthesize noise_x = self.dec_noise(z, y_mask) harm_x = self.dec_harm(pitch, z, y_mask) # dsp waveform dsp_o = torch.cat([harm_x, noise_x], axis=1) # decoder_condition = torch.cat([harm_x, noise_x, sin], axis=1) decoder_condition = self.sin_prenet(sin) # dsp based HiFiGAN vocoder wav = self.decoder((z * y_mask)[:, :, :max_len], decoder_condition, g=g) # wav = dsp_o.sum(1) # wav = noise_x # wav = harm_x.sum(1) else: wav = self.decoder((z * y_mask)[:, :, :max_len], g=g) else: # dur predict_dur = self.duration_predictor(dur_input, x_mask, g=g) predict_dur = (torch.exp(predict_dur) - 1) * x_mask predict_dur = predict_dur * self.sample_rate / self.hop_length predict_dur = torch.max(predict_dur, torch.ones_like(predict_dur).to(x)) predict_dur = torch.ceil(predict_dur).long() predict_dur = predict_dur[:, 0, :] y_lengths = torch.clamp_min(torch.sum(predict_dur, [1]), 1).long() # LR decoder_input, mel_len = self.lr(x, predict_dur, use_state_info=True) decoder_input_pitch, mel_len = self.lr( x_pitch, predict_dur, use_state_info=True ) # aam predict_lf0, predict_bn_mask = self.f0_decoder( decoder_input + decoder_input_pitch, y_lengths, g=g ) if self.generator_type == "visinger2": predict_mel, predict_bn_mask = self.mel_decoder( decoder_input + self.f0_prenet(predict_lf0), y_lengths, g=g, ) predict_energy = predict_mel.sum(1).unsqueeze(1) / self.aux_channels predict_lf0 = torch.max( predict_lf0, torch.zeros_like(predict_lf0).to(predict_lf0) ) decoder_input = decoder_input + self.f0_prenet(predict_lf0) if self.generator_type == "visinger2": decoder_input = ( decoder_input + self.energy_prenet(predict_energy) + self.mel_prenet(predict_mel) ) decoder_output, y_mask = self.prior_decoder(decoder_input, y_lengths, g=g) prior_info = decoder_output m_p = prior_info[:, : self.hidden_channels, :] logs_p = prior_info[:, self.hidden_channels :, :] # decoder z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale if self.use_flow: z = self.flow(z_p, y_mask, g=g, inverse=True) else: z = z_p F0_std = 500 F0 = predict_lf0 * F0_std F0 = F0 / 2595 F0 = torch.pow(10, F0) F0 = (F0 - 1) * 700.0 if self.vocoder_generator_type == "uhifigan": pitch_segments_expended = expand_f0( F0, self.hop_length, method=self.expand_f0_method ) pitch_segments_expended = pitch_segments_expended.reshape( -1, pitch_segments_expended.shape[-1], 1 ) sine_waves, uv, noise = self.sine_generator(pitch_segments_expended) sine_waves = sine_waves.transpose(1, 2) wav = self.decoder( (z * y_mask)[:, :, :max_len], excitation=sine_waves, g=g ) elif self.vocoder_generator_type == "avocodo": wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)[-1] elif self.vocoder_generator_type == "visinger2": pitch_ = upsample(F0.transpose(1, 2), self.hop_length) omega = torch.cumsum(2 * math.pi * pitch_ / self.sample_rate, 1) sin = torch.sin(omega).transpose(1, 2) # dsp synthesize noise_x = self.dec_noise(z, y_mask) harm_x = self.dec_harm(F0, z, y_mask) # dsp waveform dsp_o = torch.cat([harm_x, noise_x], axis=1) # decoder_condition = torch.cat([harm_x, noise_x, sin], axis=1) decoder_condition = self.sin_prenet(sin) # dsp based HiFiGAN vocoder wav = self.decoder((z * y_mask)[:, :, :max_len], decoder_condition, g=g) # wav = dsp_o.sum(1) # wav = noise_x # wav = harm_x.sum(1) else: wav = self.decoder((z * y_mask)[:, :, :max_len], g=g) return wav.squeeze(1)