Source code for espnet2.gan_svs.joint.joint_score2wav

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Joint text-to-wav module for end-to-end training."""

from typing import Any, Dict, Optional

import torch
from typeguard import check_argument_types

from espnet2.gan_svs.abs_gan_svs import AbsGANSVS
from espnet2.gan_tts.hifigan import (
    HiFiGANGenerator,
    HiFiGANMultiPeriodDiscriminator,
    HiFiGANMultiScaleDiscriminator,
    HiFiGANMultiScaleMultiPeriodDiscriminator,
    HiFiGANPeriodDiscriminator,
    HiFiGANScaleDiscriminator,
)
from espnet2.gan_tts.hifigan.loss import (
    DiscriminatorAdversarialLoss,
    FeatureMatchLoss,
    GeneratorAdversarialLoss,
    MelSpectrogramLoss,
)
from espnet2.gan_tts.melgan import MelGANGenerator, MelGANMultiScaleDiscriminator
from espnet2.gan_tts.melgan.pqmf import PQMF
from espnet2.gan_tts.parallel_wavegan import (
    ParallelWaveGANDiscriminator,
    ParallelWaveGANGenerator,
)
from espnet2.gan_tts.style_melgan import StyleMelGANDiscriminator, StyleMelGANGenerator
from espnet2.gan_tts.utils import get_random_segments, get_segments
from espnet2.svs.naive_rnn.naive_rnn_dp import NaiveRNNDP
from espnet2.svs.xiaoice.XiaoiceSing import XiaoiceSing
from espnet2.torch_utils.device_funcs import force_gatherable

AVAILABLE_SCORE2MEL = {
    "xiaoice": XiaoiceSing,
    "naive_rnn_dp": NaiveRNNDP,
}
AVAILABLE_VOCODER = {
    "hifigan_generator": HiFiGANGenerator,
    "melgan_generator": MelGANGenerator,
    "parallel_wavegan_generator": ParallelWaveGANGenerator,
    "style_melgan_generator": StyleMelGANGenerator,
}
AVAILABLE_DISCRIMINATORS = {
    "hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
    "hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
    "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
    "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
    "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
    "melgan_multi_scale_discriminator": MelGANMultiScaleDiscriminator,
    "parallel_wavegan_discriminator": ParallelWaveGANDiscriminator,
    "style_melgan_discriminator": StyleMelGANDiscriminator,
}
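
# Note: ``__init__`` below resolves these registries by config string and
# splats the matching parameter dict into the selected class constructor.
# A minimal sketch of the pattern (assuming an 80-dim mel input):
#
#     vocoder_class = AVAILABLE_VOCODER["hifigan_generator"]
#     vocoder = vocoder_class(in_channels=80)  # remaining arguments use defaults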


class JointScore2Wav(AbsGANSVS):
    """General class to jointly train score2mel and vocoder parts."""

    def __init__(
        self,
        # generator (score2mel + vocoder) related
        idim: int,
        odim: int,
        segment_size: int = 32,
        sampling_rate: int = 22050,
        score2mel_type: str = "xiaoice",
        score2mel_params: Dict[str, Any] = {
            "midi_dim": 129,
            "tempo_dim": 500,
            "embed_dim": 512,
            "adim": 384,
            "aheads": 4,
            "elayers": 6,
            "eunits": 1536,
            "dlayers": 6,
            "dunits": 1536,
            "postnet_layers": 5,
            "postnet_chans": 512,
            "postnet_filts": 5,
            "postnet_dropout_rate": 0.5,
            "positionwise_layer_type": "conv1d",
            "positionwise_conv_kernel_size": 1,
            "use_scaled_pos_enc": True,
            "use_batch_norm": True,
            "encoder_normalize_before": True,
            "decoder_normalize_before": True,
            "encoder_concat_after": False,
            "decoder_concat_after": False,
            "duration_predictor_layers": 2,
            "duration_predictor_chans": 384,
            "duration_predictor_kernel_size": 3,
            "duration_predictor_dropout_rate": 0.1,
            "reduction_factor": 1,
            "encoder_type": "transformer",
            "decoder_type": "transformer",
            "transformer_enc_dropout_rate": 0.1,
            "transformer_enc_positional_dropout_rate": 0.1,
            "transformer_enc_attn_dropout_rate": 0.1,
            "transformer_dec_dropout_rate": 0.1,
            "transformer_dec_positional_dropout_rate": 0.1,
            "transformer_dec_attn_dropout_rate": 0.1,
            # only for conformer
            "conformer_rel_pos_type": "latest",
            "conformer_pos_enc_layer_type": "rel_pos",
            "conformer_self_attn_layer_type": "rel_selfattn",
            "conformer_activation_type": "swish",
            "use_macaron_style_in_conformer": True,
            "use_cnn_in_conformer": True,
            "zero_triu": False,
            "conformer_enc_kernel_size": 7,
            "conformer_dec_kernel_size": 31,
            # extra embedding related
            "spks": None,
            "langs": None,
            "spk_embed_dim": None,
            "spk_embed_integration_type": "add",
            # training related
            "init_type": "xavier_uniform",
            "init_enc_alpha": 1.0,
            "init_dec_alpha": 1.0,
            "use_masking": False,
            "use_weighted_masking": False,
            "loss_type": "L1",
        },
        vocoder_type: str = "hifigan_generator",
        vocoder_params: Dict[str, Any] = {
            "out_channels": 1,
            "channels": 512,
            "global_channels": -1,
            "kernel_size": 7,
            "upsample_scales": [8, 8, 2, 2],
            "upsample_kernel_sizes": [16, 16, 4, 4],
            "resblock_kernel_sizes": [3, 7, 11],
            "resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "use_additional_convs": True,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
        },
        use_pqmf: bool = False,
        pqmf_params: Dict[str, Any] = {
            "subbands": 4,
            "taps": 62,
            "cutoff_ratio": 0.142,
            "beta": 9.0,
        },
        # discriminator related
        discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
        discriminator_params: Dict[str, Any] = {
            "scales": 1,
            "scale_downsample_pooling": "AvgPool1d",
            "scale_downsample_pooling_params": {
                "kernel_size": 4,
                "stride": 2,
                "padding": 2,
            },
            "scale_discriminator_params": {
                "in_channels": 1,
                "out_channels": 1,
                "kernel_sizes": [15, 41, 5, 3],
                "channels": 128,
                "max_downsample_channels": 1024,
                "max_groups": 16,
                "bias": True,
                "downsample_scales": [2, 2, 4, 4, 1],
                "nonlinear_activation": "LeakyReLU",
                "nonlinear_activation_params": {"negative_slope": 0.1},
                "use_weight_norm": True,
                "use_spectral_norm": False,
            },
            "follow_official_norm": False,
            "periods": [2, 3, 5, 7, 11],
            "period_discriminator_params": {
                "in_channels": 1,
                "out_channels": 1,
                "kernel_sizes": [5, 3],
                "channels": 32,
                "downsample_scales": [3, 3, 3, 3, 1],
                "max_downsample_channels": 1024,
                "bias": True,
                "nonlinear_activation": "LeakyReLU",
                "nonlinear_activation_params": {"negative_slope": 0.1},
                "use_weight_norm": True,
                "use_spectral_norm": False,
            },
        },
        # loss related
        generator_adv_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "loss_type": "mse",
        },
        discriminator_adv_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "loss_type": "mse",
        },
        use_feat_match_loss: bool = True,
        feat_match_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "average_by_layers": False,
            "include_final_outputs": True,
        },
        use_mel_loss: bool = True,
        mel_loss_params: Dict[str, Any] = {
            "fs": 22050,
            "n_fft": 1024,
            "hop_length": 256,
            "win_length": None,
            "window": "hann",
            "n_mels": 80,
            "fmin": 0,
            "fmax": None,
            "log_base": None,
        },
        lambda_score2mel: float = 1.0,
        lambda_adv: float = 1.0,
        lambda_feat_match: float = 2.0,
        lambda_mel: float = 45.0,
        cache_generator_outputs: bool = False,
    ):
        """Initialize JointScore2Wav module.

        Args:
            idim (int): Input vocabulary size.
            odim (int): Acoustic feature dimension. The actual output channels will
                be 1 since the model is an end-to-end score-to-wave model, but odim
                is kept to indicate the acoustic feature dimension for
                compatibility.
            segment_size (int): Segment size for random windowed inputs.
            sampling_rate (int): Sampling rate. Not used for training; it is
                referred to when saving the waveform during inference.
            score2mel_type (str): The score2mel model type.
            score2mel_params (Dict[str, Any]): Parameter dict for score2mel model.
            vocoder_type (str): The vocoder model type.
            vocoder_params (Dict[str, Any]): Parameter dict for vocoder model.
            use_pqmf (bool): Whether to use PQMF for multi-band vocoder.
            pqmf_params (Dict[str, Any]): Parameter dict for PQMF module.
            discriminator_type (str): Discriminator type.
            discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
            generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
                adversarial loss.
            discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
                discriminator adversarial loss.
            use_feat_match_loss (bool): Whether to use feat match loss.
            feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match
                loss.
            use_mel_loss (bool): Whether to use mel loss.
            mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
            lambda_score2mel (float): Loss scaling coefficient for score2mel model
                loss.
            lambda_adv (float): Loss scaling coefficient for adversarial loss.
            lambda_feat_match (float): Loss scaling coefficient for feat match loss.
            lambda_mel (float): Loss scaling coefficient for mel loss.
            cache_generator_outputs (bool): Whether to cache generator outputs.

        """
        assert check_argument_types()
        super().__init__()
        self.segment_size = segment_size
        self.use_pqmf = use_pqmf

        # define modules
        self.generator = torch.nn.ModuleDict()
        score2mel_class = AVAILABLE_SCORE2MEL[score2mel_type]
        score2mel_params.update(idim=idim, odim=odim)
        self.generator["score2mel"] = score2mel_class(
            **score2mel_params,
        )
        vocoder_class = AVAILABLE_VOCODER[vocoder_type]
        if vocoder_type in ["hifigan_generator", "melgan_generator"]:
            vocoder_params.update(in_channels=odim)
        elif vocoder_type in ["parallel_wavegan_generator", "style_melgan_generator"]:
            vocoder_params.update(aux_channels=odim)
        self.generator["vocoder"] = vocoder_class(
            **vocoder_params,
        )
        if self.use_pqmf:
            self.pqmf = PQMF(**pqmf_params)
        discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
        self.discriminator = discriminator_class(
            **discriminator_params,
        )
        self.generator_adv_loss = GeneratorAdversarialLoss(
            **generator_adv_loss_params,
        )
        self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
            **discriminator_adv_loss_params,
        )
        self.use_feat_match_loss = use_feat_match_loss
        if self.use_feat_match_loss:
            self.feat_match_loss = FeatureMatchLoss(
                **feat_match_loss_params,
            )
        self.use_mel_loss = use_mel_loss
        if self.use_mel_loss:
            self.mel_loss = MelSpectrogramLoss(
                **mel_loss_params,
            )

        # coefficients
        self.lambda_score2mel = lambda_score2mel
        self.lambda_adv = lambda_adv
        if self.use_feat_match_loss:
            self.lambda_feat_match = lambda_feat_match
        if self.use_mel_loss:
            self.lambda_mel = lambda_mel

        # cache
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

        # store sampling rate for saving wav file
        # (not used for the training)
        self.fs = sampling_rate

        # store parameters for test compatibility
        self.spks = self.generator["score2mel"].spks
        self.langs = self.generator["score2mel"].langs
        self.spk_embed_dim = self.generator["score2mel"].spk_embed_dim

    @property
    def require_raw_singing(self):
        """Return whether or not singing is required."""
        return True

    @property
    def require_vocoder(self):
        """Return whether or not vocoder is required."""
        return False

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        singing: torch.Tensor,
        singing_lengths: torch.Tensor,
        label: Optional[Dict[str, torch.Tensor]] = None,
        label_lengths: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        melody_lengths: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        duration: Optional[Dict[str, torch.Tensor]] = None,
        duration_lengths: Optional[Dict[str, torch.Tensor]] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        forward_generator: bool = True,
    ) -> Dict[str, Any]:
        """Perform generator forward.

        Args:
            text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            feats (Tensor): Batch of padded target features (B, Lmax, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            singing (Tensor): Singing waveform tensor (B, T_wav).
            singing_lengths (Tensor): Singing length tensor (B,).
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (B, Tmax).
            label_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded label ids (B,).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (B, Tmax).
            melody_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded melody (B,).
            pitch (FloatTensor): Batch of padded f0 (B, Tmax).
            pitch_lengths (LongTensor): Batch of the lengths of padded f0 (B,).
            duration (Optional[Dict]): key is "lab", "score_phn" or "score_syb";
                value (LongTensor): Batch of padded duration (B, Tmax).
            duration_lengths (Optional[Dict]): key is "lab", "score_phn" or
                "score_syb"; value (LongTensor): Batch of the lengths of padded
                duration (B,).
            spembs (Optional[Tensor]): Batch of speaker embeddings
                (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            forward_generator (bool): Whether to forward generator.

        Returns:
            Dict[str, Any]:
                - loss (Tensor): Loss scalar tensor.
                - stats (Dict[str, float]): Statistics to be monitored.
                - weight (Tensor): Weight tensor to summarize losses.
                - optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        beat = duration["lab"]
        beat_lengths = duration_lengths["lab"]
        duration = duration["lab"]
        label = label["score"]
        label_lengths = label_lengths["score"]
        melody = melody["score"]
        melody_lengths = melody_lengths["score"]

        if forward_generator:
            return self._forward_generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                singing=singing,
                singing_lengths=singing_lengths,
                duration=duration,
                label=label,
                label_lengths=label_lengths,
                melody=melody,
                melody_lengths=melody_lengths,
                beat=beat,
                beat_lengths=beat_lengths,
                pitch=pitch,
                pitch_lengths=pitch_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            return self._forward_discriminator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                singing=singing,
                singing_lengths=singing_lengths,
                duration=duration,
                label=label,
                label_lengths=label_lengths,
                melody=melody,
                melody_lengths=melody_lengths,
                beat=beat,
                beat_lengths=beat_lengths,
                pitch=pitch,
                pitch_lengths=pitch_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )

    def _forward_generator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        singing: torch.Tensor,
        singing_lengths: torch.Tensor,
        duration: torch.Tensor,
        label: Optional[Dict[str, torch.Tensor]] = None,
        label_lengths: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        melody_lengths: Optional[Dict[str, torch.Tensor]] = None,
        beat: Optional[Dict[str, torch.Tensor]] = None,
        beat_lengths: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Perform generator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            singing (Tensor): Singing waveform tensor (B, T_wav).
            singing_lengths (Tensor): Singing length tensor (B,).
            duration (LongTensor): Batch of padded duration (B, Tmax).
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (B, Tmax).
            label_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded label ids (B,).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (B, Tmax).
            melody_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded melody (B,).
            beat (Optional[Dict]): key is "lab", "score_phn" or "score_syb";
                value (LongTensor): Batch of padded beat (B, Tmax).
            beat_lengths (Optional[Dict]): key is "lab", "score_phn" or
                "score_syb"; value (LongTensor): Batch of the lengths of padded
                beat (B,).
            pitch (FloatTensor): Batch of padded f0 (B, Tmax).
            pitch_lengths (LongTensor): Batch of the lengths of padded f0 (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        # setup
        batch_size = text.size(0)
        singing = singing.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            # calculate score2mel outputs
            score2mel_loss, stats, feats_gen = self.generator["score2mel"](
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                label=label,
                label_lengths=label_lengths,
                melody=melody,
                melody_lengths=melody_lengths,
                duration=beat,
                duration_lengths=beat_lengths,
                pitch=pitch,
                pitch_lengths=pitch_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                joint_training=True,
            )
            # get random segments
            feats_gen_, start_idxs = get_random_segments(
                x=feats_gen.transpose(1, 2),
                x_lengths=feats_lengths,
                segment_size=self.segment_size,
            )
            # calculate vocoder outputs
            singing_hat_ = self.generator["vocoder"](feats_gen_)
            if self.use_pqmf:
                singing_hat_ = self.pqmf.synthesis(singing_hat_)
        else:
            score2mel_loss, stats, singing_hat_, start_idxs = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = (score2mel_loss, stats, singing_hat_, start_idxs)

        # parse outputs
        singing_ = get_segments(
            x=singing,
            start_idxs=start_idxs * self.generator["vocoder"].upsample_factor,
            segment_size=self.segment_size
            * self.generator["vocoder"].upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(singing_hat_)
        with torch.no_grad():
            # do not store discriminator gradient in generator turn
            p = self.discriminator(singing_)

        # calculate losses
        adv_loss = self.generator_adv_loss(p_hat)
        adv_loss = adv_loss * self.lambda_adv
        score2mel_loss = score2mel_loss * self.lambda_score2mel
        loss = adv_loss + score2mel_loss
        if self.use_feat_match_loss:
            feat_match_loss = self.feat_match_loss(p_hat, p)
            feat_match_loss = feat_match_loss * self.lambda_feat_match
            loss = loss + feat_match_loss
            stats.update(feat_match_loss=feat_match_loss.item())
        if self.use_mel_loss:
            mel_loss = self.mel_loss(singing_hat_, singing_)
            mel_loss = self.lambda_mel * mel_loss
            loss = loss + mel_loss
            stats.update(mel_loss=mel_loss.item())

        stats.update(
            adv_loss=adv_loss.item(),
            score2mel_loss=score2mel_loss.item(),
            loss=loss.item(),
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 0,  # needed for trainer
        }

    def _forward_discriminator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        singing: torch.Tensor,
        singing_lengths: torch.Tensor,
        duration: torch.Tensor,
        label: Optional[Dict[str, torch.Tensor]] = None,
        label_lengths: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        melody_lengths: Optional[Dict[str, torch.Tensor]] = None,
        beat: Optional[Dict[str, torch.Tensor]] = None,
        beat_lengths: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Perform discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            singing (Tensor): Singing waveform tensor (B, T_wav).
            singing_lengths (Tensor): Singing length tensor (B,).
            duration (LongTensor): Batch of padded duration (B, Tmax).
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (B, Tmax).
            label_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded label ids (B,).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (B, Tmax).
            melody_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded melody (B,).
            beat (Optional[Dict]): key is "lab", "score_phn" or "score_syb";
                value (LongTensor): Batch of padded beat (B, Tmax).
            beat_lengths (Optional[Dict]): key is "lab", "score_phn" or
                "score_syb"; value (LongTensor): Batch of the lengths of padded
                beat (B,).
            pitch (FloatTensor): Batch of padded f0 (B, Tmax).
            pitch_lengths (LongTensor): Batch of the lengths of padded f0 (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        # setup
        batch_size = text.size(0)
        singing = singing.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            # calculate score2mel outputs
            score2mel_loss, stats, feats_gen = self.generator["score2mel"](
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                label=label,
                label_lengths=label_lengths,
                melody=melody,
                melody_lengths=melody_lengths,
                duration=beat,
                duration_lengths=beat_lengths,
                pitch=pitch,
                pitch_lengths=pitch_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                joint_training=True,
            )
            # get random segments
            feats_gen_, start_idxs = get_random_segments(
                x=feats_gen.transpose(1, 2),
                x_lengths=feats_lengths,
                segment_size=self.segment_size,
            )
            # calculate vocoder outputs
            singing_hat_ = self.generator["vocoder"](feats_gen_)
            if self.use_pqmf:
                singing_hat_ = self.pqmf.synthesis(singing_hat_)
        else:
            _, _, singing_hat_, start_idxs = self._cache

        # store cache
        if self.cache_generator_outputs and not reuse_cache:
            self._cache = (score2mel_loss, stats, singing_hat_, start_idxs)

        # parse outputs
        singing_ = get_segments(
            x=singing,
            start_idxs=start_idxs * self.generator["vocoder"].upsample_factor,
            segment_size=self.segment_size
            * self.generator["vocoder"].upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(singing_hat_.detach())
        p = self.discriminator(singing_)

        # calculate losses
        real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
        loss = real_loss + fake_loss

        stats = dict(
            discriminator_loss=loss.item(),
            real_loss=real_loss.item(),
            fake_loss=fake_loss.item(),
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 1,  # needed for trainer
        }
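
    # Segment bookkeeping shared by the two private forwards above: a random
    # window of ``segment_size`` mel frames corresponds to
    # ``segment_size * upsample_factor`` waveform samples. With the default
    # HiFiGAN ``upsample_scales`` of [8, 8, 2, 2]:
    #
    #     upsample_factor = 8 * 8 * 2 * 2          # 256 samples per frame
    #     wav_segment_len = 32 * upsample_factor   # 8192 samples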

    def inference(
        self,
        text: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        label: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        duration: Optional[Dict[str, torch.Tensor]] = None,
        slur: Optional[Dict[str, torch.Tensor]] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (T_text,).
            feats (Tensor): Feature tensor (T_feats, aux_channels).
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Padded label ids (Tmax,).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Padded melody (Tmax,).
            pitch (FloatTensor): Padded f0 (Tmax,).
            duration (Optional[Dict]): key is "phn" or "syb";
                value (LongTensor): Padded duration (Tmax,).
            slur (Optional[Dict]): Padded slur (Tmax,).
            spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
            sids (Tensor): Speaker index tensor (1,).
            lids (Tensor): Language index tensor (1,).
            noise_scale (float): Noise scale value for flow.
            noise_scale_dur (float): Noise scale value for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated
                singing.
            max_len (Optional[int]): Maximum length.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Dict[str, Tensor]:
                * wav (Tensor): Generated waveform tensor (T_wav,).
                * feat_gen (Tensor): Generated feature tensor (T_feats, C).

        """
        output_dict = self.generator["score2mel"].inference(
            text=text,
            feats=feats,
            label=label,
            melody=melody,
            duration=duration,
            pitch=pitch,
            sids=sids,
            spembs=spembs,
            lids=lids,
            joint_training=True,
        )
        wav = self.generator["vocoder"].inference(output_dict["feat_gen"])
        if self.use_pqmf:
            wav = self.pqmf.synthesis(wav.unsqueeze(0).transpose(1, 2))
            wav = wav.squeeze(0).transpose(0, 1)
        output_dict.update(wav=wav)

        return output_dict
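

if __name__ == "__main__":
    # Minimal smoke-test sketch with hypothetical toy dimensions: construct
    # the joint model and report its parameter count. This only exercises
    # ``__init__``; real training and inference need score, label, melody,
    # and duration inputs prepared by the espnet2 SVS recipes.
    model = JointScore2Wav(idim=40, odim=80)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"JointScore2Wav parameters: {n_params / 1e6:.1f} M")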