# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Text-to-speech ESPnet model."""

from contextlib import contextmanager
from typing import Dict, Optional, Tuple

import torch
from packaging.version import parse as V
from typeguard import check_argument_types

from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.inversible_interface import InversibleInterface
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.tts.abs_tts import AbsTTS
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract

if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch<1.6.0 has no AMP: fall back to a no-op autocast context manager
    @contextmanager
    def autocast(enabled=True):  # NOQA
        yield
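
# With the shim above, the ``with autocast(False):`` block in ``forward``
# behaves the same on torch<1.6.0: the fallback context manager simply
# yields without toggling automatic mixed precision.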


class ESPnetTTSModel(AbsESPnetModel):
    """ESPnet model for text-to-speech task."""

    def __init__(
        self,
        feats_extract: Optional[AbsFeatsExtract],
        pitch_extract: Optional[AbsFeatsExtract],
        energy_extract: Optional[AbsFeatsExtract],
        normalize: Optional[AbsNormalize and InversibleInterface],
        pitch_normalize: Optional[AbsNormalize and InversibleInterface],
        energy_normalize: Optional[AbsNormalize and InversibleInterface],
        tts: AbsTTS,
    ):
        """Initialize ESPnetTTSModel module."""
        assert check_argument_types()
        super().__init__()
        self.feats_extract = feats_extract
        self.pitch_extract = pitch_extract
        self.energy_extract = energy_extract
        self.normalize = normalize
        self.pitch_normalize = pitch_normalize
        self.energy_normalize = energy_normalize
        self.tts = tts
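
    # A minimal construction sketch (illustrative, not prescriptive): the
    # component classes below exist in espnet2, but any combination of
    # extractors/normalizers satisfying the abstract interfaces works, and
    # ``tts`` can be any AbsTTS implementation (e.g. Tacotron2, FastSpeech2).
    #
    #   from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank
    #   from espnet2.layers.global_mvn import GlobalMVN
    #
    #   model = ESPnetTTSModel(
    #       feats_extract=LogMelFbank(fs=22050, n_fft=1024, hop_length=256),
    #       pitch_extract=None,
    #       energy_extract=None,
    #       normalize=GlobalMVN(stats_file="feats_stats.npz"),
    #       pitch_normalize=None,
    #       energy_normalize=None,
    #       tts=my_tts_module,  # hypothetical AbsTTS instance
    #   )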

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: Optional[torch.Tensor] = None,
        durations_lengths: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        energy_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate outputs and return the loss tensor.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            durations (Optional[Tensor]): Duration tensor.
            durations_lengths (Optional[Tensor]): Duration length tensor (B,).
            pitch (Optional[Tensor]): Pitch tensor.
            pitch_lengths (Optional[Tensor]): Pitch length tensor (B,).
            energy (Optional[Tensor]): Energy tensor.
            energy_lengths (Optional[Tensor]): Energy length tensor (B,).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, D).
            sids (Optional[Tensor]): Speaker ID tensor (B, 1).
            lids (Optional[Tensor]): Language ID tensor (B, 1).
            kwargs: Additional keyword arguments; "utt_id" is contained among them.

        Returns:
            Tensor: Loss scalar tensor.
            Dict[str, Tensor]: Statistics to be monitored.
            Tensor: Weight tensor to summarize losses.

        """
        with autocast(False):
            # Extract features
            if self.feats_extract is not None:
                feats, feats_lengths = self.feats_extract(speech, speech_lengths)
            else:
                # Use precalculated feats (feats_type != raw case)
                feats, feats_lengths = speech, speech_lengths

            # Extract auxiliary features
            if self.pitch_extract is not None and pitch is None:
                pitch, pitch_lengths = self.pitch_extract(
                    speech,
                    speech_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )
            if self.energy_extract is not None and energy is None:
                energy, energy_lengths = self.energy_extract(
                    speech,
                    speech_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )

            # Normalize
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
            if self.pitch_normalize is not None:
                pitch, pitch_lengths = self.pitch_normalize(pitch, pitch_lengths)
            if self.energy_normalize is not None:
                energy, energy_lengths = self.energy_normalize(energy, energy_lengths)

        # Make batch for tts inputs
        batch = dict(
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
        )

        # Update batch for additional auxiliary inputs
        if spembs is not None:
            batch.update(spembs=spembs)
        if sids is not None:
            batch.update(sids=sids)
        if lids is not None:
            batch.update(lids=lids)
        if durations is not None:
            batch.update(durations=durations, durations_lengths=durations_lengths)
        if self.pitch_extract is not None and pitch is not None:
            batch.update(pitch=pitch, pitch_lengths=pitch_lengths)
        if self.energy_extract is not None and energy is not None:
            batch.update(energy=energy, energy_lengths=energy_lengths)
        if self.tts.require_raw_speech:
            batch.update(speech=speech, speech_lengths=speech_lengths)

        return self.tts(**batch)
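
    # Hedged usage sketch for ``forward`` (tensor shapes are illustrative):
    #
    #   loss, stats, weight = model(
    #       text=text,                      # (B, T_text) token ID LongTensor
    #       text_lengths=text_lengths,      # (B,)
    #       speech=speech,                  # (B, T_wav) raw waveform
    #       speech_lengths=speech_lengths,  # (B,)
    #   )
    #   loss.backward()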

    def collect_feats(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: Optional[torch.Tensor] = None,
        durations_lengths: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        energy_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Calculate features and return them as a dict.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            durations (Optional[Tensor]): Duration tensor.
            durations_lengths (Optional[Tensor]): Duration length tensor (B,).
            pitch (Optional[Tensor]): Pitch tensor.
            pitch_lengths (Optional[Tensor]): Pitch length tensor (B,).
            energy (Optional[Tensor]): Energy tensor.
            energy_lengths (Optional[Tensor]): Energy length tensor (B,).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, D).
            sids (Optional[Tensor]): Speaker ID tensor (B, 1).
            lids (Optional[Tensor]): Language ID tensor (B, 1).

        Returns:
            Dict[str, Tensor]: Dict of features.

        """
        # Feature extraction
        if self.feats_extract is not None:
            feats, feats_lengths = self.feats_extract(speech, speech_lengths)
        else:
            # Use precalculated feats (feats_type != raw case)
            feats, feats_lengths = speech, speech_lengths
        if self.pitch_extract is not None:
            pitch, pitch_lengths = self.pitch_extract(
                speech,
                speech_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )
        if self.energy_extract is not None:
            energy, energy_lengths = self.energy_extract(
                speech,
                speech_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )

        # Store in dict
        feats_dict = dict(feats=feats, feats_lengths=feats_lengths)
        if pitch is not None:
            feats_dict.update(pitch=pitch, pitch_lengths=pitch_lengths)
        if energy is not None:
            feats_dict.update(energy=energy, energy_lengths=energy_lengths)
        return feats_dict
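
    # ``collect_feats`` is what the espnet2 trainer calls during its
    # statistics-collection stage; a hedged sketch of direct use:
    #
    #   feats_dict = model.collect_feats(
    #       text=text, text_lengths=text_lengths,
    #       speech=speech, speech_lengths=speech_lengths,
    #   )
    #   feats = feats_dict["feats"]  # un-normalized features, e.g. for GlobalMVN stats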

    def inference(
        self,
        text: torch.Tensor,
        speech: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        durations: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        **decode_config,
    ) -> Dict[str, torch.Tensor]:
        """Run text-to-speech inference and return the outputs as a dict.

        Args:
            text (Tensor): Text index tensor (T_text).
            speech (Optional[Tensor]): Speech waveform tensor (T_wav);
                required when teacher forcing or GST is used.
            spembs (Optional[Tensor]): Speaker embedding tensor (D,).
            sids (Optional[Tensor]): Speaker ID tensor (1,).
            lids (Optional[Tensor]): Language ID tensor (1,).
            durations (Optional[Tensor]): Duration tensor.
            pitch (Optional[Tensor]): Pitch tensor.
            energy (Optional[Tensor]): Energy tensor.
            decode_config: Decoding options; must contain "use_teacher_forcing".

        Returns:
            Dict[str, Tensor]: Dict of outputs.

        """
        input_dict = dict(text=text)
        if decode_config["use_teacher_forcing"] or getattr(self.tts, "use_gst", False):
            if speech is None:
                raise RuntimeError("missing required argument: 'speech'")
            if self.feats_extract is not None:
                feats = self.feats_extract(speech[None])[0][0]
            else:
                # Use precalculated feats (feats_type != raw case)
                feats = speech
            if self.normalize is not None:
                feats = self.normalize(feats[None])[0][0]
            input_dict.update(feats=feats)
            if self.tts.require_raw_speech:
                input_dict.update(speech=speech)

        if decode_config["use_teacher_forcing"]:
            if durations is not None:
                input_dict.update(durations=durations)

            if self.pitch_extract is not None:
                pitch = self.pitch_extract(
                    speech[None],
                    feats_lengths=torch.LongTensor([len(feats)]),
                    # Guard against durations=None (durations[None] would raise)
                    durations=durations[None] if durations is not None else None,
                )[0][0]
            if self.pitch_normalize is not None:
                pitch = self.pitch_normalize(pitch[None])[0][0]
            if pitch is not None:
                input_dict.update(pitch=pitch)

            if self.energy_extract is not None:
                energy = self.energy_extract(
                    speech[None],
                    feats_lengths=torch.LongTensor([len(feats)]),
                    # Guard against durations=None (durations[None] would raise)
                    durations=durations[None] if durations is not None else None,
                )[0][0]
            if self.energy_normalize is not None:
                energy = self.energy_normalize(energy[None])[0][0]
            if energy is not None:
                input_dict.update(energy=energy)

        if spembs is not None:
            input_dict.update(spembs=spembs)
        if sids is not None:
            input_dict.update(sids=sids)
        if lids is not None:
            input_dict.update(lids=lids)

        output_dict = self.tts.inference(**input_dict, **decode_config)

        if self.normalize is not None and output_dict.get("feat_gen") is not None:
            # NOTE: normalize.inverse is an in-place operation
            feat_gen_denorm = self.normalize.inverse(
                output_dict["feat_gen"].clone()[None]
            )[0][0]
            output_dict.update(feat_gen_denorm=feat_gen_denorm)
        return output_dict
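
# Hedged usage sketch for ``inference`` (key availability depends on the
# configured components; ``use_teacher_forcing`` must be supplied because
# ``inference`` reads it from ``decode_config``):
#
#   with torch.no_grad():
#       output_dict = model.inference(
#           text=token_ids,  # (T_text,) LongTensor for a single utterance
#           use_teacher_forcing=False,
#       )
#   # "feat_gen_denorm" is present only when ``normalize`` is set and the
#   # TTS module returns "feat_gen".
#   mel = output_dict.get("feat_gen_denorm", output_dict.get("feat_gen"))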