# Copyright 2021 Carnegie Mellon University (Jiatong Shi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Naive-SVS related modules."""
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from espnet2.svs.abs_svs import AbsSVS
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.torch_utils.initialize import initialize
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet
from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet
class NaiveRNNLoss(torch.nn.Module):
"""Loss function module for Tacotron2."""
def __init__(self, use_masking=True, use_weighted_masking=False):
"""Initialize Tactoron2 loss module.
Args:
use_masking (bool): Whether to apply masking
for padded part in loss calculation.
use_weighted_masking (bool):
Whether to apply weighted masking in loss calculation.
"""
super(NaiveRNNLoss, self).__init__()
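        # masking and weighted masking must not be enabled at the same time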
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
# NOTE(kan-bayashi): register pre hook function for the compatibility
# self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
    def forward(self, after_outs, before_outs, ys, olens):
"""Calculate forward propagation.
Args:
after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
Returns:
Tensor: L1 loss value.
Tensor: Mean square error loss value.
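
        Example (a minimal sketch; tensor shapes are illustrative):
            >>> criterion = NaiveRNNLoss(use_masking=True)
            >>> ys = torch.randn(2, 10, 80)
            >>> outs = torch.randn(2, 10, 80)
            >>> olens = torch.tensor([10, 8])
            >>> l1_loss, mse_loss = criterion(outs, outs, ys, olens)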
"""
# make mask and apply it
if self.use_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
ys = ys.masked_select(masks)
after_outs = after_outs.masked_select(masks)
before_outs = before_outs.masked_select(masks)
# calculate loss
l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys)
mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
before_outs, ys
)
# make weighted mask and apply it
if self.use_weighted_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
out_weights = weights.div(ys.size(0) * ys.size(2))
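            # each frame is weighted by 1 / olen, so every utterance contributes
            # equally regardless of its length; the division above then averages
            # over the batch and feature dimensions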
# apply weight
l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()
return l1_loss, mse_loss
class NaiveRNN(AbsSVS):
"""NaiveRNN-SVS module.
    This is an implementation of a naive RNN for singing voice synthesis.
    The music score features are processed directly over the time domain to
    predict the singing voice features.
"""
def __init__(
self,
# network structure related
idim: int,
odim: int,
midi_dim: int = 129,
embed_dim: int = 512,
eprenet_conv_layers: int = 3,
eprenet_conv_chans: int = 256,
eprenet_conv_filts: int = 5,
elayers: int = 3,
eunits: int = 1024,
ebidirectional: bool = True,
midi_embed_integration_type: str = "add",
dlayers: int = 3,
dunits: int = 1024,
dbidirectional: bool = True,
postnet_layers: int = 5,
postnet_chans: int = 256,
postnet_filts: int = 5,
use_batch_norm: bool = True,
reduction_factor: int = 1,
# extra embedding related
spks: Optional[int] = None,
langs: Optional[int] = None,
spk_embed_dim: Optional[int] = None,
spk_embed_integration_type: str = "add",
eprenet_dropout_rate: float = 0.5,
edropout_rate: float = 0.1,
ddropout_rate: float = 0.1,
postnet_dropout_rate: float = 0.5,
init_type: str = "xavier_uniform",
use_masking: bool = False,
use_weighted_masking: bool = False,
loss_type: str = "L1",
):
"""Initialize NaiveRNN module.
Args: TODO(Yuning)
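
        Example (an illustrative sketch; all dimensions here are assumptions):
            >>> model = NaiveRNN(idim=10, odim=4, eprenet_conv_layers=0,
            ...                  eunits=8, postnet_layers=0)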
"""
assert check_argument_types()
super().__init__()
# store hyperparameters
self.idim = idim
self.midi_dim = midi_dim
self.eunits = eunits
self.odim = odim
self.eos = idim - 1
self.reduction_factor = reduction_factor
self.loss_type = loss_type
self.midi_embed_integration_type = midi_embed_integration_type
# use idx 0 as padding idx
self.padding_idx = 0
# define transformer encoder
if eprenet_conv_layers != 0:
# encoder prenet
self.encoder_input_layer = torch.nn.Sequential(
EncoderPrenet(
idim=idim,
embed_dim=embed_dim,
elayers=0,
econv_layers=eprenet_conv_layers,
econv_chans=eprenet_conv_chans,
econv_filts=eprenet_conv_filts,
use_batch_norm=use_batch_norm,
dropout_rate=eprenet_dropout_rate,
padding_idx=self.padding_idx,
),
torch.nn.Linear(eprenet_conv_chans, eunits),
)
self.midi_encoder_input_layer = torch.nn.Sequential(
EncoderPrenet(
idim=midi_dim,
embed_dim=embed_dim,
elayers=0,
econv_layers=eprenet_conv_layers,
econv_chans=eprenet_conv_chans,
econv_filts=eprenet_conv_filts,
use_batch_norm=use_batch_norm,
dropout_rate=eprenet_dropout_rate,
padding_idx=self.padding_idx,
),
torch.nn.Linear(eprenet_conv_chans, eunits),
)
else:
self.encoder_input_layer = torch.nn.Embedding(
num_embeddings=idim, embedding_dim=eunits, padding_idx=self.padding_idx
)
self.midi_encoder_input_layer = torch.nn.Embedding(
num_embeddings=midi_dim,
embedding_dim=eunits,
padding_idx=self.padding_idx,
)
self.encoder = torch.nn.LSTM(
input_size=eunits,
hidden_size=eunits,
num_layers=elayers,
batch_first=True,
dropout=edropout_rate,
bidirectional=ebidirectional,
)
self.midi_encoder = torch.nn.LSTM(
input_size=eunits,
hidden_size=eunits,
num_layers=elayers,
batch_first=True,
dropout=edropout_rate,
bidirectional=ebidirectional,
)
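        # a bidirectional LSTM doubles the effective hidden dimension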
        dim_direction = 2 if ebidirectional else 1
if self.midi_embed_integration_type == "add":
self.midi_projection = torch.nn.Linear(
eunits * dim_direction, eunits * dim_direction
)
else:
            self.midi_projection = torch.nn.Linear(
2 * eunits * dim_direction, eunits * dim_direction
)
self.decoder = torch.nn.LSTM(
input_size=eunits,
hidden_size=eunits,
num_layers=dlayers,
batch_first=True,
dropout=ddropout_rate,
bidirectional=dbidirectional,
)
# define spk and lang embedding
self.spks = None
if spks is not None and spks > 1:
self.spks = spks
self.sid_emb = torch.nn.Embedding(spks, eunits * dim_direction)
self.langs = None
if langs is not None and langs > 1:
# TODO(Yuning): not encode yet
self.langs = langs
self.lid_emb = torch.nn.Embedding(langs, eunits * dim_direction)
# define projection layer
self.spk_embed_dim = None
if spk_embed_dim is not None and spk_embed_dim > 0:
self.spk_embed_dim = spk_embed_dim
self.spk_embed_integration_type = spk_embed_integration_type
if self.spk_embed_dim is not None:
if self.spk_embed_integration_type == "add":
self.projection = torch.nn.Linear(
self.spk_embed_dim, eunits * dim_direction
)
else:
self.projection = torch.nn.Linear(
eunits * dim_direction + self.spk_embed_dim, eunits * dim_direction
)
# define final projection
self.feat_out = torch.nn.Linear(eunits * dim_direction, odim * reduction_factor)
# define postnet
self.postnet = (
None
if postnet_layers == 0
else Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate,
)
)
# define loss function
self.criterion = NaiveRNNLoss(
use_masking=use_masking,
use_weighted_masking=use_weighted_masking,
)
# initialize parameters
self._reset_parameters(
init_type=init_type,
)
def _reset_parameters(self, init_type):
# initialize parameters
if init_type != "pytorch":
initialize(self, init_type)
    def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
label: Optional[Dict[str, torch.Tensor]] = None,
label_lengths: Optional[Dict[str, torch.Tensor]] = None,
melody: Optional[Dict[str, torch.Tensor]] = None,
melody_lengths: Optional[Dict[str, torch.Tensor]] = None,
pitch: Optional[torch.Tensor] = None,
pitch_lengths: Optional[torch.Tensor] = None,
duration: Optional[Dict[str, torch.Tensor]] = None,
duration_lengths: Optional[Dict[str, torch.Tensor]] = None,
        slur: Optional[torch.LongTensor] = None,
        slur_lengths: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
        flag_IsValid: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Calculate forward propagation.
Args:
text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
feats (Tensor): Batch of padded target features (B, Lmax, odim).
feats_lengths (LongTensor): Batch of the lengths of each target (B,).
label (Optional[Dict]): key is "lab" or "score";
value (LongTensor): Batch of padded label ids (B, Tmax).
label_lengths (Optional[Dict]): key is "lab" or "score";
value (LongTensor): Batch of the lengths of padded label ids (B, ).
melody (Optional[Dict]): key is "lab" or "score";
value (LongTensor): Batch of padded melody (B, Tmax).
melody_lengths (Optional[Dict]): key is "lab" or "score";
value (LongTensor): Batch of the lengths of padded melody (B, ).
pitch (FloatTensor): Batch of padded f0 (B, Tmax).
pitch_lengths (LongTensor): Batch of the lengths of padded f0 (B, ).
duration (Optional[Dict]): key is "lab", "score";
value (LongTensor): Batch of padded duration (B, Tmax).
duration_lengths (Optional[Dict]): key is "lab" or "score";
value (LongTensor): Batch of the lengths of padded duration (B, ).
slur (LongTensor): Batch of padded slur (B, Tmax).
slur_lengths (LongTensor): Batch of the lengths of padded slur (B, ).
spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
lids (Optional[Tensor]): Batch of language IDs (B, 1).
        Note:
            Correspondence between the arguments of this function and
            ``**batch`` from espnet_model.py:
            ``label`` corresponds to the duration / phone sequence;
            ``melody`` corresponds to the pitch sequence.
Returns:
Tensor: Loss scalar value.
Dict: Statistics to be monitored.
Tensor: Weight value if not joint training else model outputs.
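
        Example (a minimal sketch; dimensions are assumptions and label /
        melody are assumed to be frame-aligned with ``feats``):
            >>> model = NaiveRNN(idim=10, odim=4, eprenet_conv_layers=0,
            ...                  eunits=8, postnet_layers=0)
            >>> text = torch.randint(1, 10, (2, 24))
            >>> lengths = torch.tensor([24, 20])
            >>> feats = torch.randn(2, 24, 4)
            >>> loss, stats, weight = model(
            ...     text, lengths, feats, lengths,
            ...     label={"lab": text}, label_lengths={"lab": lengths},
            ...     melody={"lab": torch.randint(0, 129, (2, 24))},
            ...     melody_lengths={"lab": lengths},
            ... )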
"""
label = label["lab"]
midi = melody["lab"]
label_lengths = label_lengths["lab"]
midi_lengths = melody_lengths["lab"]
text = text[:, : text_lengths.max()] # for data-parallel
feats = feats[:, : feats_lengths.max()] # for data-parallel
midi = midi[:, : midi_lengths.max()] # for data-parallel
label = label[:, : label_lengths.max()] # for data-parallel
batch_size = feats.size(0)
        label_emb = self.encoder_input_layer(label)  # FIXME: cast label from Float to Int
midi_emb = self.midi_encoder_input_layer(midi)
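        # pack the padded sequences so that the LSTMs skip the padded frames;
        # pack_padded_sequence requires the length tensors to live on the CPU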
label_emb = torch.nn.utils.rnn.pack_padded_sequence(
label_emb, label_lengths.to("cpu"), batch_first=True, enforce_sorted=False
)
midi_emb = torch.nn.utils.rnn.pack_padded_sequence(
midi_emb, midi_lengths.to("cpu"), batch_first=True, enforce_sorted=False
)
hs_label, (_, _) = self.encoder(label_emb)
hs_midi, (_, _) = self.midi_encoder(midi_emb)
hs_label, _ = torch.nn.utils.rnn.pad_packed_sequence(hs_label, batch_first=True)
hs_midi, _ = torch.nn.utils.rnn.pad_packed_sequence(hs_midi, batch_first=True)
if self.midi_embed_integration_type == "add":
hs = hs_label + hs_midi
hs = F.leaky_relu(self.midi_projection(hs))
else:
hs = torch.cat((hs_label, hs_midi), dim=-1)
hs = F.leaky_relu(self.midi_projection(hs))
# integrate spk & lang embeddings
if self.spks is not None:
sid_embs = self.sid_emb(sids.view(-1))
hs = hs + sid_embs.unsqueeze(1)
if self.langs is not None:
lid_embs = self.lid_emb(lids.view(-1))
hs = hs + lid_embs.unsqueeze(1)
# integrate speaker embedding
if self.spk_embed_dim is not None:
hs = self._integrate_with_spk_embed(hs, spembs)
# (B, T_feats//r, odim * r) -> (B, T_feats//r * r, odim)
before_outs = F.leaky_relu(self.feat_out(hs).view(hs.size(0), -1, self.odim))
# postnet -> (B, T_feats//r * r, odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(
before_outs.transpose(1, 2)
).transpose(1, 2)
        # modify the ground truth so its length is a multiple of the reduction factor
if self.reduction_factor > 1:
assert feats_lengths.ge(
self.reduction_factor
).all(), "Output length must be greater than or equal to reduction factor."
olens = feats_lengths.new(
[olen - olen % self.reduction_factor for olen in feats_lengths]
)
max_olen = max(olens)
ys = feats[:, :max_olen]
else:
ys = feats
olens = feats_lengths
# calculate loss values
l1_loss, l2_loss = self.criterion(
after_outs[:, : olens.max()], before_outs[:, : olens.max()], ys, olens
)
if self.loss_type == "L1":
loss = l1_loss
elif self.loss_type == "L2":
loss = l2_loss
elif self.loss_type == "L1+L2":
loss = l1_loss + l2_loss
else:
raise ValueError("unknown --loss-type " + self.loss_type)
stats = dict(
loss=loss.item(),
l1_loss=l1_loss.item(),
l2_loss=l2_loss.item(),
)
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
if flag_IsValid is False:
# training stage
return loss, stats, weight
else:
# validation stage
return loss, stats, weight, after_outs[:, : olens.max()], ys, olens
    def inference(
self,
text: torch.Tensor,
feats: Optional[torch.Tensor] = None,
label: Optional[Dict[str, torch.Tensor]] = None,
melody: Optional[Dict[str, torch.Tensor]] = None,
pitch: Optional[torch.Tensor] = None,
duration: Optional[Dict[str, torch.Tensor]] = None,
slur: Optional[Dict[str, torch.Tensor]] = None,
spembs: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
        use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Calculate forward propagation.
Args:
text (LongTensor): Batch of padded character ids (Tmax).
feats (Tensor): Batch of padded target features (Lmax, odim).
label (Optional[Dict]): key is "lab" or "score";
value (LongTensor): Batch of padded label ids (Tmax).
melody (Optional[Dict]): key is "lab" or "score";
value (LongTensor): Batch of padded melody (Tmax).
pitch (FloatTensor): Batch of padded f0 (Tmax).
            duration (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded duration (Tmax).
            slur (LongTensor): Batch of padded slur (Tmax).
spembs (Optional[Tensor]): Batch of speaker embeddings (spk_embed_dim).
sids (Optional[Tensor]): Batch of speaker IDs (1).
lids (Optional[Tensor]): Batch of language IDs (1).
Returns:
Dict[str, Tensor]: Output dict including the following items:
* feat_gen (Tensor): Output sequence of features (T_feats, odim).
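
        Example (an illustrative sketch; all dimensions are assumptions):
            >>> model = NaiveRNN(idim=10, odim=4, eprenet_conv_layers=0,
            ...                  eunits=8, postnet_layers=0)
            >>> _ = model.eval()
            >>> out = model.inference(
            ...     torch.randint(1, 10, (24,)),
            ...     label={"lab": torch.randint(1, 10, (1, 24))},
            ...     melody={"lab": torch.randint(0, 129, (1, 24))},
            ... )
            >>> out["feat_gen"].shape
            torch.Size([24, 4])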
"""
label = label["lab"]
midi = melody["lab"]
label_emb = self.encoder_input_layer(label)
midi_emb = self.midi_encoder_input_layer(midi)
hs_label, (_, _) = self.encoder(label_emb)
hs_midi, (_, _) = self.midi_encoder(midi_emb)
if self.midi_embed_integration_type == "add":
hs = hs_label + hs_midi
hs = F.leaky_relu(self.midi_projection(hs))
else:
hs = torch.cat((hs_label, hs_midi), dim=-1)
hs = F.leaky_relu(self.midi_projection(hs))
# integrate spk & lang embeddings
if self.spks is not None:
sid_embs = self.sid_emb(sids.view(-1))
hs = hs + sid_embs.unsqueeze(1)
if self.langs is not None:
lid_embs = self.lid_emb(lids.view(-1))
hs = hs + lid_embs.unsqueeze(1)
# integrate speaker embedding
if self.spk_embed_dim is not None:
hs = self._integrate_with_spk_embed(hs, spembs)
# (B, T_feats//r, odim * r) -> (B, T_feats//r * r, odim)
before_outs = F.leaky_relu(self.feat_out(hs).view(hs.size(0), -1, self.odim))
# postnet -> (B, T_feats//r * r, odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(
before_outs.transpose(1, 2)
).transpose(1, 2)
return dict(
feat_gen=after_outs[0], prob=None, att_w=None
) # outs, probs, att_ws
def _integrate_with_spk_embed(
self, hs: torch.Tensor, spembs: torch.Tensor
) -> torch.Tensor:
"""Integrate speaker embedding with hidden states.
Args:
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
Returns:
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).
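
        Example (a minimal sketch; dimensions are assumptions):
            >>> model = NaiveRNN(idim=10, odim=4, eprenet_conv_layers=0,
            ...                  eunits=8, postnet_layers=0, spk_embed_dim=6)
            >>> hs = torch.randn(2, 5, 16)
            >>> spembs = torch.randn(2, 6)
            >>> model._integrate_with_spk_embed(hs, spembs).shape
            torch.Size([2, 5, 16])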
"""
if self.spk_embed_integration_type == "add":
# apply projection and then add to hidden states
spembs = self.projection(F.normalize(spembs))
hs = hs + spembs.unsqueeze(1)
elif self.spk_embed_integration_type == "concat":
# concat hidden states with spk embeds and then apply projection
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = self.projection(torch.cat([hs, spembs], dim=-1))
else:
raise NotImplementedError("support only add or concat.")
return hs