# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Copyright 2021 Carnegie Mellon University (Jiatong Shi)
# Copyright 2022 Renmin University of China (Yuning Wu)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Singing-voice-synthesis ESPnet model."""
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict, Optional, Tuple
import torch
from typeguard import check_argument_types
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.inversible_interface import InversibleInterface
from espnet2.svs.abs_svs import AbsSVS
from espnet2.svs.feats_extract.score_feats_extract import (
FrameScoreFeats,
SyllableScoreFeats,
expand_to_frame,
)
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# torch<1.6.0 does not provide torch.cuda.amp.autocast, so define a no-op fallback
@contextmanager
def autocast(enabled=True): # NOQA
yield
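# NOTE: forward() below wraps feature extraction in autocast(False) so that
# STFT-based feature extraction is kept in full precision even when the rest
# of training runs under automatic mixed precision.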
class ESPnetSVSModel(AbsESPnetModel):
"""ESPnet model for singing voice synthesis task."""
def __init__(
self,
text_extract: Optional[AbsFeatsExtract],
feats_extract: Optional[AbsFeatsExtract],
score_feats_extract: Optional[AbsFeatsExtract],
label_extract: Optional[AbsFeatsExtract],
pitch_extract: Optional[AbsFeatsExtract],
ying_extract: Optional[AbsFeatsExtract],
duration_extract: Optional[AbsFeatsExtract],
energy_extract: Optional[AbsFeatsExtract],
normalize: Optional[AbsNormalize and InversibleInterface],
pitch_normalize: Optional[AbsNormalize and InversibleInterface],
energy_normalize: Optional[AbsNormalize and InversibleInterface],
svs: AbsSVS,
):
"""Initialize ESPnetSVSModel module."""
assert check_argument_types()
super().__init__()
self.text_extract = text_extract
self.feats_extract = feats_extract
self.score_feats_extract = score_feats_extract
self.label_extract = label_extract
self.pitch_extract = pitch_extract
self.ying_extract = ying_extract
self.duration_extract = duration_extract
self.energy_extract = energy_extract
self.normalize = normalize
self.pitch_normalize = pitch_normalize
self.energy_normalize = energy_normalize
self.svs = svs
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
singing: torch.Tensor,
singing_lengths: torch.Tensor,
feats: Optional[torch.Tensor] = None,
feats_lengths: Optional[torch.Tensor] = None,
label: Optional[torch.Tensor] = None,
label_lengths: Optional[torch.Tensor] = None,
phn_cnt: Optional[torch.Tensor] = None,
midi: Optional[torch.Tensor] = None,
midi_lengths: Optional[torch.Tensor] = None,
duration_phn: Optional[torch.Tensor] = None,
duration_phn_lengths: Optional[torch.Tensor] = None,
duration_ruled_phn: Optional[torch.Tensor] = None,
duration_ruled_phn_lengths: Optional[torch.Tensor] = None,
duration_syb: Optional[torch.Tensor] = None,
duration_syb_lengths: Optional[torch.Tensor] = None,
slur: Optional[torch.Tensor] = None,
slur_lengths: Optional[torch.Tensor] = None,
pitch: Optional[torch.Tensor] = None,
pitch_lengths: Optional[torch.Tensor] = None,
energy: Optional[torch.Tensor] = None,
energy_lengths: Optional[torch.Tensor] = None,
ying: Optional[torch.Tensor] = None,
ying_lengths: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
flag_IsValid=False,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Caclualte outputs and return the loss tensor.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
singing (Tensor): Singing waveform tensor (B, T_wav).
singing_lengths (Tensor): Singing length tensor (B,).
label (Optional[Tensor]): Label tensor (B, T_label).
label_lengths (Optional[Tensor]): Label length tensor (B,).
phn_cnt (Optional[Tensor]): Number of phones in each syllable (B, T_syb).
midi (Optional[Tensor]): Midi tensor (B, T_label).
midi_lengths (Optional[Tensor]): Midi length tensor (B,).
duration_phn (Optional[Tensor]): duration tensor (B, T_label).
duration_phn_lengths (Optional[Tensor]): duration length tensor (B,).
duration_ruled_phn (Optional[Tensor]): duration tensor (B, T_phone).
duration_ruled_phn_lengths (Optional[Tensor]): duration length tensor (B,).
duration_syb (Optional[Tensor]): duration tensor (B, T_syllable).
duration_syb_lengths (Optional[Tensor]): duration length tensor (B,).
slur (Optional[Tensor]): slur tensor (B, T_slur).
slur_lengths (Optional[Tensor]): slur length tensor (B,).
pitch (Optional[Tensor]): Pitch tensor (B, T_wav), i.e., the f0 sequence.
pitch_lengths (Optional[Tensor]): Pitch length tensor (B,).
energy (Optional[Tensor]): Energy tensor.
energy_lengths (Optional[Tensor]): Energy length tensor (B,).
spembs (Optional[Tensor]): Speaker embedding tensor (B, D).
sids (Optional[Tensor]): Speaker ID tensor (B, 1).
lids (Optional[Tensor]): Language ID tensor (B, 1).
kwargs: "utt_id" is among the input.
Returns:
Tensor: Loss scalar tensor.
Dict[str, float]: Statistics to be monitored.
Tensor: Weight tensor to summarize losses.
"""
with autocast(False):
# Extract features
if self.feats_extract is not None and feats is None:
feats, feats_lengths = self.feats_extract(
singing, singing_lengths
) # singing to spec feature (frame level)
# Extract auxiliary features
# melody : 128 note pitch
# duration :
# input -> phone-id sequence
# output -> frame level (take mode from window) or syllable level
# cut length
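# Make feats_lengths consistent with the total phoneme duration of each
# utterance: if the extracted features are longer, truncate feats_lengths;
# otherwise shrink the trailing phoneme durations (zeroing them if needed)
# until the totals match.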
for i in range(feats.size(0)):
dur_len = sum(duration_phn[i])
if feats_lengths[i] > dur_len:
feats_lengths[i] = dur_len
else: # decrease duration at the end of sequence
delta = dur_len - feats_lengths[i]
end = duration_phn_lengths[i] - 1
while delta > 0 and end >= 0:
new = duration_phn[i][end] - delta
if new < 0: # keep on decreasing the previous one
delta -= duration_phn[i][end]
duration_phn[i][end] = 0
end -= 1
else: # stop
delta -= duration_phn[i][end] - new
duration_phn[i][end] = new
feats = feats[:, : feats_lengths.max()]
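# Score features are handled per extractor type: with FrameScoreFeats,
# expand_to_frame() expands label/midi/duration to frame-level sequences
# using the phoneme durations; with SyllableScoreFeats, the phoneme- and
# syllable-level sequences are kept and only trailing padding is trimmed.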
if isinstance(self.score_feats_extract, FrameScoreFeats):
(
label_lab,
label_lab_lengths,
midi_lab,
midi_lab_lengths,
duration_lab,
duration_lab_lengths,
) = expand_to_frame(
duration_phn, duration_phn_lengths, label, midi, duration_phn
)
# for data-parallel
label_lab = label_lab[:, : label_lab_lengths.max()]
midi_lab = midi_lab[:, : midi_lab_lengths.max()]
duration_lab = duration_lab[:, : duration_lab_lengths.max()]
(
label_score,
label_score_lengths,
midi_score,
midi_score_lengths,
duration_score,
duration_score_phn_lengths,
) = expand_to_frame(
duration_ruled_phn,
duration_ruled_phn_lengths,
label,
midi,
duration_ruled_phn,
)
# for data-parallel
label_score = label_score[:, : label_score_lengths.max()]
midi_score = midi_score[:, : midi_score_lengths.max()]
duration_score = duration_score[:, : duration_score_phn_lengths.max()]
duration_score_syb = None
elif isinstance(self.score_feats_extract, SyllableScoreFeats):
label_lab_lengths = label_lengths
midi_lab_lengths = midi_lengths
duration_lab_lengths = duration_phn_lengths
label_lab = label[:, : label_lab_lengths.max()]
midi_lab = midi[:, : midi_lab_lengths.max()]
duration_lab = duration_phn[:, : duration_lab_lengths.max()]
label_score_lengths = label_lengths
midi_score_lengths = midi_lengths
duration_score_phn_lengths = duration_ruled_phn_lengths
duration_score_syb_lengths = duration_syb_lengths
label_score = label[:, : label_score_lengths.max()]
midi_score = midi[:, : midi_score_lengths.max()]
duration_score = duration_ruled_phn[
:, : duration_score_phn_lengths.max()
]
duration_score_syb = duration_syb[:, : duration_score_syb_lengths.max()]
slur = slur[:, : slur_lengths.max()]
else:
raise RuntimeError("Cannot understand score_feats extract type")
if self.pitch_extract is not None and pitch is None:
pitch, pitch_lengths = self.pitch_extract(
input=singing,
input_lengths=singing_lengths,
feats_lengths=feats_lengths,
)
if self.energy_extract is not None and energy is None:
energy, energy_lengths = self.energy_extract(
singing,
singing_lengths,
feats_lengths=feats_lengths,
)
if self.ying_extract is not None and ying is None:
ying, ying_lengths = self.ying_extract(
singing,
singing_lengths,
feats_lengths=feats_lengths,
)
# Normalize
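# The normalizers here are typically global mean-variance statistics
# (e.g., GlobalMVN) estimated in a separate collect-stats stage.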
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
if self.pitch_normalize is not None:
pitch, pitch_lengths = self.pitch_normalize(pitch, pitch_lengths)
if self.energy_normalize is not None:
energy, energy_lengths = self.energy_normalize(energy, energy_lengths)
# Make batch for svs inputs
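# label/melody/duration are passed to the SVS module as dicts with up to two
# variants: "lab" entries derived from the annotated phoneme durations and
# "score" entries derived from the rule-based durations of the music score.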
batch = dict(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
flag_IsValid=flag_IsValid,
)
# label
# NOTE(Yuning): Label can be word, syllable or phoneme,
# which is determined by annotation file.
label = dict()
label_lengths = dict()
if label_lab is not None:
label_lab = label_lab.to(dtype=torch.long)
label.update(lab=label_lab)
label_lengths.update(lab=label_lab_lengths)
if label_score is not None:
label_score = label_score.to(dtype=torch.long)
label.update(score=label_score)
label_lengths.update(score=label_score_lengths)
batch.update(label=label, label_lengths=label_lengths)
# melody
melody = dict()
melody_lengths = dict()
if midi_lab is not None:
midi_lab = midi_lab.to(dtype=torch.long)
melody.update(lab=midi_lab)
melody_lengths.update(lab=midi_lab_lengths)
if midi_score is not None:
midi_score = midi_score.to(dtype=torch.long)
melody.update(score=midi_score)
melody_lengths.update(score=midi_score_lengths)
batch.update(melody=melody, melody_lengths=melody_lengths)
# duration
# NOTE(Yuning): duration = duration_time / time_shift (same as Xiaoice paper)
duration = dict()
duration_lengths = dict()
if duration_lab is not None:
duration_lab = duration_lab.to(dtype=torch.long)
duration.update(lab=duration_lab)
duration_lengths.update(lab=duration_lab_lengths)
if duration_score is not None:
duration_phn_score = duration_score.to(dtype=torch.long)
duration.update(score_phn=duration_phn_score)
duration_lengths.update(score_phn=duration_score_phn_lengths)
if duration_score_syb is not None:
duration_syb_score = duration_score_syb.to(dtype=torch.long)
duration.update(score_syb=duration_syb_score)
duration_lengths.update(score_syb=duration_score_syb_lengths)
batch.update(duration=duration, duration_lengths=duration_lengths)
if slur is not None:
batch.update(slur=slur, slur_lengths=slur_lengths)
if spembs is not None:
batch.update(spembs=spembs)
if sids is not None:
batch.update(sids=sids)
if lids is not None:
batch.update(lids=lids)
if self.pitch_extract is not None and pitch is not None:
batch.update(pitch=pitch, pitch_lengths=pitch_lengths)
if self.energy_extract is not None and energy is not None:
batch.update(energy=energy, energy_lengths=energy_lengths)
if self.ying_extract is not None and ying is not None:
batch.update(ying=ying)
if self.svs.require_raw_singing:
batch.update(singing=singing, singing_lengths=singing_lengths)
return self.svs(**batch)
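# NOTE: self.svs(**batch) is expected to return (loss, stats, weight) as in
# the Returns section above; stats is what the trainer reports, and weight is
# typically the batch size used to average losses across devices.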
def collect_feats(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
singing: torch.Tensor,
singing_lengths: torch.Tensor,
label: Optional[torch.Tensor] = None,
label_lengths: Optional[torch.Tensor] = None,
phn_cnt: Optional[torch.Tensor] = None,
midi: Optional[torch.Tensor] = None,
midi_lengths: Optional[torch.Tensor] = None,
duration_phn: Optional[torch.Tensor] = None,
duration_phn_lengths: Optional[torch.Tensor] = None,
duration_ruled_phn: Optional[torch.Tensor] = None,
duration_ruled_phn_lengths: Optional[torch.Tensor] = None,
duration_syb: Optional[torch.Tensor] = None,
duration_syb_lengths: Optional[torch.Tensor] = None,
slur: Optional[torch.Tensor] = None,
slur_lengths: Optional[torch.Tensor] = None,
pitch: Optional[torch.Tensor] = None,
pitch_lengths: Optional[torch.Tensor] = None,
energy: Optional[torch.Tensor] = None,
energy_lengths: Optional[torch.Tensor] = None,
ying: Optional[torch.Tensor] = None,
ying_lengths: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Caclualte features and return them as a dict.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
singing (Tensor): Singing waveform tensor (B, T_wav).
singing_lengths (Tensor): Singing length tensor (B,).
label (Optional[Tensor]): Label tensor (B, T_label).
label_lengths (Optional[Tensor]): Label length tensor (B,).
phn_cnt (Optional[Tensor]): Number of phones in each syllable (B, T_syb).
midi (Optional[Tensor]): Midi tensor (B, T_label).
midi_lengths (Optional[Tensor]): Midi length tensor (B,).
---- each duration* below is measured in units of time_shift (frames) ----
duration_phn (Optional[Tensor]): duration tensor (B, T_label).
duration_phn_lengths (Optional[Tensor]): duration length tensor (B,).
duration_ruled_phn (Optional[Tensor]): duration tensor (B, T_phone).
duration_ruled_phn_lengths (Optional[Tensor]): duration length tensor (B,).
duration_syb (Optional[Tensor]): duration tensor (B, T_syb).
duration_syb_lengths (Optional[Tensor]): duration length tensor (B,).
slur (Optional[Tensor]): slur tensor (B, T_slur).
slur_lengths (Optional[Tensor]): slur length tensor (B,).
pitch (Optional[Tensor]): Pitch tensor (B, T_wav), i.e., the f0 sequence.
pitch_lengths (Optional[Tensor]): Pitch length tensor (B,).
energy (Optional[Tensor]): Energy tensor.
energy_lengths (Optional[Tensor]): Energy length tensor (B,).
spembs (Optional[Tensor]): Speaker embedding tensor (B, D).
sids (Optional[Tensor]): Speaker ID tensor (B, 1).
lids (Optional[Tensor]): Language ID tensor (B, 1).
Returns:
Dict[str, Tensor]: Dict of features.
"""
feats = None
if self.feats_extract is not None:
feats, feats_lengths = self.feats_extract(singing, singing_lengths)
else:
# Use precalculated feats (feats_type != raw case)
feats, feats_lengths = singing, singing_lengths
# cut length
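# Same alignment between feats_lengths and phoneme durations as in forward().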
for i in range(feats.size(0)):
dur_len = sum(duration_phn[i])
if feats_lengths[i] > dur_len:
feats_lengths[i] = dur_len
else: # decrease duration at the end of sequence
delta = dur_len - feats_lengths[i]
end = duration_phn_lengths[i] - 1
while delta > 0 and end >= 0:
new = duration_phn[i][end] - delta
if new < 0: # keep on decreasing the previous one
delta -= duration_phn[i][end]
duration_phn[i][end] = 0
end -= 1
else: # stop
delta -= duration_phn[i][end] - new
duration_phn[i][end] = new
feats = feats[:, : feats_lengths.max()]
if self.pitch_extract is not None:
pitch, pitch_lengths = self.pitch_extract(
input=singing,
input_lengths=singing_lengths,
feats_lengths=feats_lengths,
)
if self.energy_extract is not None:
energy, energy_lengths = self.energy_extract(
singing,
singing_lengths,
feats_lengths=feats_lengths,
)
if self.ying_extract is not None and ying is None:
ying, ying_lengths = self.ying_extract(
singing,
singing_lengths,
feats_lengths=feats_lengths,
)
# store in dict
feats_dict = {}
if feats is not None:
feats_dict.update(feats=feats, feats_lengths=feats_lengths)
if pitch is not None:
feats_dict.update(pitch=pitch, pitch_lengths=pitch_lengths)
if energy is not None:
feats_dict.update(energy=energy, energy_lengths=energy_lengths)
if ying is not None:
feats_dict.update(ying=ying, ying_lengths=ying_lengths)
return feats_dict
def inference(
self,
text: torch.Tensor,
singing: Optional[torch.Tensor] = None,
label: Optional[torch.Tensor] = None,
phn_cnt: Optional[torch.Tensor] = None,
midi: Optional[torch.Tensor] = None,
duration_phn: Optional[torch.Tensor] = None,
duration_ruled_phn: Optional[torch.Tensor] = None,
duration_syb: Optional[torch.Tensor] = None,
slur: Optional[torch.Tensor] = None,
pitch: Optional[torch.Tensor] = None,
energy: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
**decode_config,
) -> Dict[str, torch.Tensor]:
"""Caclualte features and return them as a dict.
Args:
text (Tensor): Text index tensor (T_text).
singing (Tensor): Singing waveform tensor (T_wav).
label (Optional[Tensor]): Label tensor (T_label).
phn_cnt (Optional[Tensor]): Number of phones in each syllable (T_syb).
midi (Optional[Tensor]): Midi tensor (T_label).
duration_phn (Optional[Tensor]): duration tensor (T_label).
duration_ruled_phn (Optional[Tensor]): duration tensor (T_phone).
duration_syb (Optional[Tensor]): duration tensor (T_syllable).
slur (Optional[Tensor]): slur tensor (T_slur).
spembs (Optional[Tensor]): Speaker embedding tensor (D,).
sids (Optional[Tensor]): Speaker ID tensor (1,).
lids (Optional[Tensor]): Language ID tensor (1,).
pitch (Optional[Tensor]): Pitch tensor (T_wav).
energy (Optional[Tensor]): Energy tensor.
Returns:
Dict[str, Tensor]: Dict of outputs.
"""
label_lengths = torch.tensor([len(label)])
midi_lengths = torch.tensor([len(midi)])
duration_phn_lengths = torch.tensor([len(duration_phn)])
duration_ruled_phn_lengths = torch.tensor([len(duration_ruled_phn)])
duration_syb_lengths = torch.tensor([len(duration_syb)])
slur_lengths = torch.tensor([len(slur)])
# unsqueeze of singing is needed; otherwise it causes an error in the STFT dimension
# for data-parallel
text = text.unsqueeze(0)
label = label.unsqueeze(0)
midi = midi.unsqueeze(0)
duration_phn = duration_phn.unsqueeze(0)
duration_ruled_phn = duration_ruled_phn.unsqueeze(0)
duration_syb = duration_syb.unsqueeze(0)
phn_cnt = phn_cnt.unsqueeze(0)
slur = slur.unsqueeze(0)
# Extract auxiliary features
# melody : 128 midi pitch
# duration :
# input -> phone-id sequence
# output -> frame level or syllable level
batch_size = text.size(0)
assert batch_size == 1
if isinstance(self.score_feats_extract, FrameScoreFeats):
(
label_lab,
label_lab_lengths,
midi_lab,
midi_lab_lengths,
duration_lab,
duration_lab_lengths,
) = expand_to_frame(
duration_phn, duration_phn_lengths, label, midi, duration_phn
)
# for data-parallel
label_lab = label_lab[:, : label_lab_lengths.max()]
midi_lab = midi_lab[:, : midi_lab_lengths.max()]
duration_lab = duration_lab[:, : duration_lab_lengths.max()]
(
label_score,
label_score_lengths,
midi_score,
midi_score_lengths,
duration_score,
duration_score_phn_lengths,
) = expand_to_frame(
duration_ruled_phn,
duration_ruled_phn_lengths,
label,
midi,
duration_ruled_phn,
)
# for data-parallel
label_score = label_score[:, : label_score_lengths.max()]
midi_score = midi_score[:, : midi_score_lengths.max()]
duration_score = duration_score[:, : duration_score_phn_lengths.max()]
duration_score_syb = None
elif isinstance(self.score_feats_extract, SyllableScoreFeats):
# Remove unused paddings at end
label_lab = label[:, : label_lengths.max()]
midi_lab = midi[:, : midi_lengths.max()]
duration_lab = duration_phn[:, : duration_phn_lengths.max()]
label_score = label[:, : label_lengths.max()]
midi_score = midi[:, : midi_lengths.max()]
duration_score = duration_ruled_phn[:, : duration_ruled_phn_lengths.max()]
duration_score_syb = duration_syb[:, : duration_syb_lengths.max()]
slur = slur[:, : slur_lengths.max()]
input_dict = dict(text=text)
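# Ground-truth acoustic features are needed only when teacher forcing is
# requested or when the SVS module uses a global style token (GST) reference;
# in those cases the raw singing waveform must be provided.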
if decode_config["use_teacher_forcing"] or getattr(self.svs, "use_gst", False):
if singing is None:
raise RuntimeError("missing required argument: 'singing'")
if self.feats_extract is not None:
feats = self.feats_extract(singing[None])[0][0]
else:
# Use precalculated feats (feats_type != raw case)
feats = singing
if self.normalize is not None:
feats = self.normalize(feats[None])[0][0]
input_dict.update(feats=feats)
# if self.svs.require_raw_singing:
# input_dict.update(singing=singing)
if decode_config["use_teacher_forcing"]:
if self.pitch_extract is not None:
pitch = self.pitch_extract(
singing[None],
feats_lengths=torch.LongTensor([len(feats)]),
)[0][0]
if self.pitch_normalize is not None:
pitch = self.pitch_normalize(pitch[None])[0][0]
if pitch is not None:
input_dict.update(pitch=pitch)
if self.energy_extract is not None:
energy = self.energy_extract(
singing[None],
feats_lengths=torch.LongTensor([len(feats)]),
)[0][0]
if self.energy_normalize is not None:
energy = self.energy_normalize(energy[None])[0][0]
if energy is not None:
input_dict.update(energy=energy)
# label
label = dict()
if label_lab is not None:
label_lab = label_lab.to(dtype=torch.long)
label.update(lab=label_lab)
if label_score is not None:
label_score = label_score.to(dtype=torch.long)
label.update(score=label_score)
input_dict.update(label=label)
# melody
melody = dict()
if midi_lab is not None:
midi_lab = midi_lab.to(dtype=torch.long)
melody.update(lab=midi_lab)
if midi_score is not None:
midi_score = midi_score.to(dtype=torch.long)
melody.update(score=midi_score)
input_dict.update(melody=melody)
# duration
duration = dict()
if duration_lab is not None:
duration_lab = duration_lab.to(dtype=torch.long)
duration.update(lab=duration_lab)
if duration_score is not None:
duration_phn_score = duration_score.to(dtype=torch.long)
duration.update(score_phn=duration_phn_score)
if duration_score_syb is not None:
duration_syb_score = duration_score_syb.to(dtype=torch.long)
duration.update(score_syb=duration_syb_score)
input_dict.update(duration=duration)
if slur is not None:
input_dict.update(slur=slur)
if spembs is not None:
input_dict.update(spembs=spembs)
if sids is not None:
input_dict.update(sids=sids)
if lids is not None:
input_dict.update(lids=lids)
output_dict = self.svs.inference(**input_dict, **decode_config)
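# De-normalize the generated features so that downstream components
# (typically a vocoder) receive them in the original feature scale.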
if self.normalize is not None and output_dict.get("feat_gen") is not None:
# NOTE: normalize.inverse is in-place operation
feat_gen_denorm = self.normalize.inverse(
output_dict["feat_gen"].clone()[None]
)[0][0]
output_dict.update(feat_gen_denorm=feat_gen_denorm)
return output_dict
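# Illustrative usage sketch (hypothetical names): in ESPnet2 recipes this
# model is usually built by the SVS task from a training config and then used as
#   loss, stats, weight = model(**batch)        # training step
#   output_dict = model.inference(**one_utt)    # synthesis
# where `batch` and `one_utt` are hypothetical dicts containing the tensors
# documented in forward() and inference() above.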