# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""GAN-based text-to-speech task."""

import argparse
import logging
from typing import Callable, Collection, Dict, List, Optional, Tuple

import numpy as np
import torch
from typeguard import check_argument_types, check_return_type

from espnet2.gan_tts.abs_gan_tts import AbsGANTTS
from espnet2.gan_tts.espnet_model import ESPnetGANTTSModel
from espnet2.gan_tts.jets import JETS
from espnet2.gan_tts.joint import JointText2Wav
from espnet2.gan_tts.vits import VITS
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.global_mvn import GlobalMVN
from espnet2.layers.utterance_mvn import UtteranceMVN
from espnet2.tasks.abs_task import AbsTask, optim_classes
from espnet2.text.phoneme_tokenizer import g2p_choices
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.gan_trainer import GANTrainer
from espnet2.train.preprocessor import CommonPreprocessor
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract
from espnet2.tts.feats_extract.dio import Dio
from espnet2.tts.feats_extract.energy import Energy
from espnet2.tts.feats_extract.linear_spectrogram import LinearSpectrogram
from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank
from espnet2.tts.feats_extract.log_spectrogram import LogSpectrogram
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import int_or_none, str2bool, str_or_none
feats_extractor_choices = ClassChoices(
    "feats_extract",
    classes=dict(
        fbank=LogMelFbank,
        log_spectrogram=LogSpectrogram,
        linear_spectrogram=LinearSpectrogram,
    ),
    type_check=AbsFeatsExtract,
    default="linear_spectrogram",
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
tts_choices = ClassChoices(
    "tts",
    classes=dict(
        vits=VITS,
        joint_text2wav=JointText2Wav,
        jets=JETS,
    ),
    type_check=AbsGANTTS,
    default="vits",
)
pitch_extractor_choices = ClassChoices(
    "pitch_extract",
    classes=dict(dio=Dio),
    type_check=AbsFeatsExtract,
    default=None,
    optional=True,
)
energy_extractor_choices = ClassChoices(
    "energy_extract",
    classes=dict(energy=Energy),
    type_check=AbsFeatsExtract,
    default=None,
    optional=True,
)
pitch_normalize_choices = ClassChoices(
    "pitch_normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
energy_normalize_choices = ClassChoices(
    "energy_normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
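
# Illustrative sketch of how the ClassChoices objects above are resolved at
# build time (this mirrors the get_class() calls made in build_model() below);
# the keyword arguments shown are placeholders -- in practice they come from
# the corresponding --<name>_conf option:
#
#     feats_extract_class = feats_extractor_choices.get_class("fbank")
#     feats_extract = feats_extract_class(fs=22050, n_mels=80)  # LogMelFbank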


class GANTTSTask(AbsTask):
    """GAN-based text-to-speech task."""

    # GAN requires two optimizers
    num_optimizers: int = 2

    # Add variable objects configurations
    class_choices_list = [
        # --feats_extract and --feats_extract_conf
        feats_extractor_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --tts and --tts_conf
        tts_choices,
        # --pitch_extract and --pitch_extract_conf
        pitch_extractor_choices,
        # --pitch_normalize and --pitch_normalize_conf
        pitch_normalize_choices,
        # --energy_extract and --energy_extract_conf
        energy_extractor_choices,
        # --energy_normalize and --energy_normalize_conf
        energy_normalize_choices,
    ]

    # Use GANTrainer instead of Trainer
    trainer = GANTrainer
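    # NOTE: GANTrainer consumes the two optimizers returned by
    # build_optimizers() below and alternates generator and discriminator
    # updates; the exact update schedule is defined in
    # espnet2.train.gan_trainer (this comment is a sketch of intent, not a
    # restatement of that implementation).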

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Add task-specific arguments to the parser."""
        # NOTE(kamo): Use '_' instead of '-' to avoid confusion
        assert check_argument_types()
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as follows:
        required = parser.get_default("required")
        required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text file that maps int-id to token",
        )
        group.add_argument(
            "--odim",
            type=int_or_none,
            default=None,
            help="The number of dimensions of the output feature",
        )
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetGANTTSModel),
            help="The keyword arguments for the model class.",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to the data",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="phn",
            choices=["bpe", "char", "word", "phn"],
            help="The level at which the text will be tokenized",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of SentencePiece",
        )
        group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            help="The path to the non_linguistic_symbols file",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
            default=None,
            help="Apply text cleaning",
        )
        group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf,
            # e.g. --tts and --tts_conf.
            class_choices.add_arguments(group)
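
    # Usage sketch (illustrative): the flags registered above are normally
    # passed through the training entry point; the paths below are
    # placeholders, and --output_dir comes from AbsTask's common arguments
    # rather than this method.
    #
    #     python -m espnet2.bin.gan_tts_train \
    #         --token_list dump/token_list/phn_none/tokens.txt \
    #         --tts vits \
    #         --feats_extract linear_spectrogram \
    #         --output_dir exp/tts_vits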

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Build the batch collation function."""
        assert check_argument_types()
        return CommonCollateFn(
            float_pad_value=0.0,
            int_pad_value=0,
            not_sequence=["spembs", "sids", "lids"],
        )
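
    # Sketch of the returned callable's contract (illustrative; see
    # CommonCollateFn for the authoritative behavior): it maps a list of
    # (utt_id, {name: ndarray}) pairs to (utt_ids, batch), where sequence
    # entries such as "text" are padded and paired with a "*_lengths"
    # tensor, while "spembs", "sids" and "lids" are stacked without lengths.
    #
    #     collate = GANTTSTask.build_collate_fn(args, train=True)
    #     utt_ids, batch = collate([("utt1", {"text": np.array([3, 5, 7])})])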

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Build the preprocessing function applied to each example."""
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Define the required data names for training and inference."""
        if not inference:
            retval = ("text", "speech")
        else:
            # Inference mode
            retval = ("text",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Define the optional data names for training and inference."""
        if not inference:
            retval = (
                "spembs",
                "durations",
                "pitch",
                "energy",
                "sids",
                "lids",
            )
        else:
            # Inference mode
            retval = (
                "spembs",
                "speech",
                "durations",
                "pitch",
                "energy",
                "sids",
                "lids",
            )
        return retval
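
    # Illustrative mapping from these data names to input files (the dump
    # layout and paths are recipe conventions, shown here only as
    # placeholders):
    #
    #     --train_data_path_and_name_and_type dump/raw/tr_no_dev/text,text,text
    #     --train_data_path_and_name_and_type dump/raw/tr_no_dev/wav.scp,speech,sound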

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetGANTTSModel:
        """Build the GAN-based TTS model from the parsed arguments."""
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line[0] + line[1:].rstrip() for line in f]

            # "args" is saved as it is in a yaml file by AbsTask.main().
            # Overwriting token_list to keep the saved config portable.
            args.token_list = token_list.copy()
        elif isinstance(args.token_list, (tuple, list)):
            token_list = args.token_list.copy()
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. feats_extract
        if args.odim is None:
            # Extract features in the model
            feats_extract_class = feats_extractor_choices.get_class(args.feats_extract)
            feats_extract = feats_extract_class(**args.feats_extract_conf)
            odim = feats_extract.output_size()
        else:
            # Give features from data-loader
            args.feats_extract = None
            args.feats_extract_conf = None
            feats_extract = None
            odim = args.odim

        # 2. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 3. TTS
        tts_class = tts_choices.get_class(args.tts)
        tts = tts_class(idim=vocab_size, odim=odim, **args.tts_conf)

        # 4. Extra components
        pitch_extract = None
        energy_extract = None
        pitch_normalize = None
        energy_normalize = None
        if getattr(args, "pitch_extract", None) is not None:
            pitch_extract_class = pitch_extractor_choices.get_class(
                args.pitch_extract,
            )
            pitch_extract = pitch_extract_class(
                **args.pitch_extract_conf,
            )
        if getattr(args, "energy_extract", None) is not None:
            energy_extract_class = energy_extractor_choices.get_class(
                args.energy_extract,
            )
            energy_extract = energy_extract_class(
                **args.energy_extract_conf,
            )
        if getattr(args, "pitch_normalize", None) is not None:
            pitch_normalize_class = pitch_normalize_choices.get_class(
                args.pitch_normalize,
            )
            pitch_normalize = pitch_normalize_class(
                **args.pitch_normalize_conf,
            )
        if getattr(args, "energy_normalize", None) is not None:
            energy_normalize_class = energy_normalize_choices.get_class(
                args.energy_normalize,
            )
            energy_normalize = energy_normalize_class(
                **args.energy_normalize_conf,
            )

        # 5. Build model
        model = ESPnetGANTTSModel(
            feats_extract=feats_extract,
            normalize=normalize,
            pitch_extract=pitch_extract,
            pitch_normalize=pitch_normalize,
            energy_extract=energy_extract,
            energy_normalize=energy_normalize,
            tts=tts,
            **args.model_conf,
        )
        assert check_return_type(model)
        return model
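
    # Illustrative use of build_model() outside the usual CLI flow; every
    # value in the Namespace below is a placeholder and in practice comes
    # from the parser assembled in add_task_arguments():
    #
    #     args = argparse.Namespace(
    #         token_list=["<blank>", "<unk>", "a", "i", "<sos/eos>"],
    #         odim=None,
    #         feats_extract="linear_spectrogram",
    #         feats_extract_conf={},
    #         normalize=None,
    #         tts="vits",
    #         tts_conf={},
    #         pitch_extract=None,
    #         energy_extract=None,
    #         pitch_normalize=None,
    #         energy_normalize=None,
    #         model_conf={},
    #     )
    #     model = GANTTSTask.build_model(args)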

    @classmethod
    def build_optimizers(
        cls,
        args: argparse.Namespace,
        model: ESPnetGANTTSModel,
    ) -> List[torch.optim.Optimizer]:
        """Build separate optimizers for the generator and the discriminator."""
        # check
        assert hasattr(model.tts, "generator")
        assert hasattr(model.tts, "discriminator")

        # define generator optimizer
        optim_g_class = optim_classes.get(args.optim)
        if optim_g_class is None:
            raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
        if args.sharded_ddp:
            try:
                import fairscale
            except ImportError:
                raise RuntimeError(
                    "fairscale is not installed. Please install it via "
                    "'pip install fairscale'"
                )
            optim_g = fairscale.optim.oss.OSS(
                params=model.tts.generator.parameters(),
                optim=optim_g_class,
                **args.optim_conf,
            )
        else:
            optim_g = optim_g_class(
                model.tts.generator.parameters(),
                **args.optim_conf,
            )
        optimizers = [optim_g]

        # define discriminator optimizer
        optim_d_class = optim_classes.get(args.optim2)
        if optim_d_class is None:
            raise ValueError(f"must be one of {list(optim_classes)}: {args.optim2}")
        if args.sharded_ddp:
            try:
                import fairscale
            except ImportError:
                raise RuntimeError(
                    "fairscale is not installed. Please install it via "
                    "'pip install fairscale'"
                )
            optim_d = fairscale.optim.oss.OSS(
                params=model.tts.discriminator.parameters(),
                optim=optim_d_class,
                **args.optim2_conf,
            )
        else:
            optim_d = optim_d_class(
                model.tts.discriminator.parameters(),
                **args.optim2_conf,
            )
        optimizers += [optim_d]

        return optimizers
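

# Typical entry point (sketch): training is usually launched through the thin
# wrapper in espnet2.bin.gan_tts_train, which forwards command-line arguments
# to this task; the snippet below is shown only for orientation.
#
#     from espnet2.tasks.gan_tts import GANTTSTask
#
#     if __name__ == "__main__":
#         GANTTSTask.main()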