import argparse
import copy
import os
from typing import Callable, Collection, Dict, List, Optional, Tuple
import numpy as np
import torch
from typeguard import check_argument_types, check_return_type
from espnet2.diar.layers.abs_mask import AbsMask
from espnet2.diar.layers.multi_mask import MultiMask
from espnet2.diar.separator.tcn_separator_nomask import TCNSeparatorNomask
from espnet2.enh.decoder.abs_decoder import AbsDecoder
from espnet2.enh.decoder.conv_decoder import ConvDecoder
from espnet2.enh.decoder.null_decoder import NullDecoder
from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.enh.encoder.conv_encoder import ConvEncoder
from espnet2.enh.encoder.null_encoder import NullEncoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
from espnet2.enh.loss.criterions.tf_domain import (
FrequencyDomainAbsCoherence,
FrequencyDomainDPCL,
FrequencyDomainL1,
FrequencyDomainMSE,
)
from espnet2.enh.loss.criterions.time_domain import (
CISDRLoss,
MultiResL1SpecLoss,
SDRLoss,
SISNRLoss,
SNRLoss,
TimeDomainL1,
TimeDomainMSE,
)
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
from espnet2.enh.loss.wrappers.dpcl_solver import DPCLSolver
from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver
from espnet2.enh.loss.wrappers.mixit_solver import MixITSolver
from espnet2.enh.loss.wrappers.multilayer_pit_solver import MultiLayerPITSolver
from espnet2.enh.loss.wrappers.pit_solver import PITSolver
from espnet2.enh.separator.abs_separator import AbsSeparator
from espnet2.enh.separator.asteroid_models import AsteroidModel_Converter
from espnet2.enh.separator.conformer_separator import ConformerSeparator
from espnet2.enh.separator.dan_separator import DANSeparator
from espnet2.enh.separator.dc_crn_separator import DC_CRNSeparator
from espnet2.enh.separator.dccrn_separator import DCCRNSeparator
from espnet2.enh.separator.dpcl_e2e_separator import DPCLE2ESeparator
from espnet2.enh.separator.dpcl_separator import DPCLSeparator
from espnet2.enh.separator.dprnn_separator import DPRNNSeparator
from espnet2.enh.separator.dptnet_separator import DPTNetSeparator
from espnet2.enh.separator.fasnet_separator import FaSNetSeparator
from espnet2.enh.separator.ineube_separator import iNeuBe
from espnet2.enh.separator.neural_beamformer import NeuralBeamformer
from espnet2.enh.separator.rnn_separator import RNNSeparator
from espnet2.enh.separator.skim_separator import SkiMSeparator
from espnet2.enh.separator.svoice_separator import SVoiceSeparator
from espnet2.enh.separator.tcn_separator import TCNSeparator
from espnet2.enh.separator.tfgridnet_separator import TFGridNet
from espnet2.enh.separator.transformer_separator import TransformerSeparator
from espnet2.iterators.abs_iter_factory import AbsIterFactory
from espnet2.tasks.abs_task import AbsTask
from espnet2.torch_utils.initialize import initialize
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.preprocessor import (
AbsPreprocessor,
DynamicMixingPreprocessor,
EnhPreprocessor,
)
from espnet2.train.trainer import Trainer
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import str2bool, str_or_none
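
# Each ClassChoices below registers the selectable classes for one component of
# the enhancement model and provides the corresponding --<name> and
# --<name>_conf command-line options (see add_task_arguments below).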
encoder_choices = ClassChoices(
name="encoder",
classes=dict(stft=STFTEncoder, conv=ConvEncoder, same=NullEncoder),
type_check=AbsEncoder,
default="stft",
)

separator_choices = ClassChoices(
name="separator",
classes=dict(
asteroid=AsteroidModel_Converter,
conformer=ConformerSeparator,
dan=DANSeparator,
dc_crn=DC_CRNSeparator,
dccrn=DCCRNSeparator,
dpcl=DPCLSeparator,
dpcl_e2e=DPCLE2ESeparator,
dprnn=DPRNNSeparator,
dptnet=DPTNetSeparator,
fasnet=FaSNetSeparator,
rnn=RNNSeparator,
skim=SkiMSeparator,
svoice=SVoiceSeparator,
tcn=TCNSeparator,
transformer=TransformerSeparator,
wpe_beamformer=NeuralBeamformer,
tcn_nomask=TCNSeparatorNomask,
ineube=iNeuBe,
tfgridnet=TFGridNet,
),
type_check=AbsSeparator,
default="rnn",
)

mask_module_choices = ClassChoices(
name="mask_module",
classes=dict(multi_mask=MultiMask),
type_check=AbsMask,
default="multi_mask",
)

decoder_choices = ClassChoices(
name="decoder",
classes=dict(stft=STFTDecoder, conv=ConvDecoder, same=NullDecoder),
type_check=AbsDecoder,
default="stft",
)

loss_wrapper_choices = ClassChoices(
name="loss_wrappers",
classes=dict(
pit=PITSolver,
fixed_order=FixedOrderSolver,
multilayer_pit=MultiLayerPITSolver,
dpcl=DPCLSolver,
mixit=MixITSolver,
),
type_check=AbsLossWrapper,
default=None,
)
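
# Each entry of --criterions pairs a loss criterion below with one of the loss
# wrappers above (e.g. a PIT solver); see add_task_arguments and build_model.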
criterion_choices = ClassChoices(
name="criterions",
classes=dict(
ci_sdr=CISDRLoss,
coh=FrequencyDomainAbsCoherence,
sdr=SDRLoss,
si_snr=SISNRLoss,
snr=SNRLoss,
l1=FrequencyDomainL1,
dpcl=FrequencyDomainDPCL,
l1_fd=FrequencyDomainL1,
l1_td=TimeDomainL1,
mse=FrequencyDomainMSE,
mse_fd=FrequencyDomainMSE,
mse_td=TimeDomainMSE,
mr_l1_tfd=MultiResL1SpecLoss,
),
type_check=AbsEnhLoss,
default=None,
)

preprocessor_choices = ClassChoices(
name="preprocessor",
classes=dict(
dynamic_mixing=DynamicMixingPreprocessor,
enh=EnhPreprocessor,
),
type_check=AbsPreprocessor,
default=None,
)
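
# Upper bound on the number of enumerated reference signals
# (speech_ref*, noise_ref*, dereverb_ref*) listed in optional_data_names().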
MAX_REFERENCE_NUM = 100


class EnhancementTask(AbsTask):
    # If you need more than one optimizer, change this value.
num_optimizers: int = 1
class_choices_list = [
# --encoder and --encoder_conf
encoder_choices,
# --separator and --separator_conf
separator_choices,
# --decoder and --decoder_conf
decoder_choices,
# --mask_module and --mask_module_conf
mask_module_choices,
# --preprocessor and --preprocessor_conf
preprocessor_choices,
]

    # If you need to modify the train() or eval() procedures,
    # change the Trainer class here.
trainer = Trainer

    @classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # together with the --print_config mode. Instead, do:
        # required = parser.get_default("required")
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetEnhancementModel),
help="The keyword arguments for model class.",
)
group.add_argument(
"--criterions",
action=NestedDictAction,
default=[
{
"name": "si_snr",
"conf": {},
"wrapper": "fixed_order",
"wrapper_conf": {},
},
],
help="The criterions binded with the loss wrappers.",
)
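        # For example, a training config might contain (illustrative; the names
        # must be registered in criterion_choices / loss_wrapper_choices):
        #
        #   criterions:
        #     - name: si_snr
        #       conf: {}
        #       wrapper: pit
        #       wrapper_conf: {}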
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--speech_volume_normalize",
type=str_or_none,
default=None,
help="Scale the maximum amplitude to the given value or range. "
"e.g. --speech_volume_normalize 1.0 scales it to 1.0.\n"
"--speech_volume_normalize 0.5_1.0 scales it to a random number in "
"the range [0.5, 1.0)",
)
group.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
group.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
group.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
group.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
group.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of signal-to-noise ratio (SNR) level in decibel.",
)
group.add_argument(
"--short_noise_thres",
type=float,
default=0.5,
help="If len(noise) / len(speech) is smaller than this threshold during "
"dynamic mixing, a warning will be displayed.",
)
group.add_argument(
"--use_reverberant_ref",
type=str2bool,
default=False,
help="Whether to use reverberant speech references "
"instead of anechoic ones",
)
group.add_argument(
"--num_spk",
type=int,
default=1,
help="Number of speakers in the input signal.",
)
group.add_argument(
"--num_noise_type",
type=int,
default=1,
help="Number of noise types.",
)
group.add_argument(
"--sample_rate",
type=int,
default=8000,
help="Sampling rate of the data (in Hz).",
)
group.add_argument(
"--force_single_channel",
type=str2bool,
default=False,
help="Whether to force all data to be single-channel.",
)
group.add_argument(
"--channel_reordering",
type=str2bool,
default=False,
help="Whether to randomly reorder the channels of the "
"multi-channel signals.",
)
group.add_argument(
"--categories",
nargs="+",
default=[],
type=str,
help="The set of all possible categories in the dataset. Used to add the "
"category information to each sample",
)
group.add_argument(
"--dynamic_mixing",
type=str2bool,
default=False,
help="Apply dynamic mixing",
)
group.add_argument(
"--utt2spk",
type=str_or_none,
default=None,
help="The file path of utt2spk file. Only used in dynamic_mixing mode.",
)
group.add_argument(
"--dynamic_mixing_gain_db",
type=float,
default=0.0,
help="Random gain (in dB) for dynamic mixing sources",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)

    @classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
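        # Pad float features with 0.0 and integer features with 0 when batching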
return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)

    @classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
assert check_argument_types()
use_preprocessor = getattr(args, "preprocessor", None) is not None
if use_preprocessor:
# TODO(simpleoier): To make this as simple as model parts, e.g. encoder
if args.preprocessor == "dynamic_mixing":
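                # The source scp (default "spk1.scp") is resolved relative to
                # the directory of the first training data file.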
retval = preprocessor_choices.get_class(args.preprocessor)(
train=train,
source_scp=os.path.join(
os.path.dirname(args.train_data_path_and_name_and_type[0][0]),
args.preprocessor_conf.get("source_scp_name", "spk1.scp"),
),
ref_num=args.preprocessor_conf.get(
"ref_num", args.separator_conf["num_spk"]
),
dynamic_mixing_gain_db=args.preprocessor_conf.get(
"dynamic_mixing_gain_db", 0.0
),
speech_name=args.preprocessor_conf.get("speech_name", "speech_mix"),
speech_ref_name_prefix=args.preprocessor_conf.get(
"speech_ref_name_prefix", "speech_ref"
),
mixture_source_name=args.preprocessor_conf.get(
"mixture_source_name", None
),
utt2spk=getattr(args, "utt2spk", None),
categories=args.preprocessor_conf.get("categories", None),
)
elif args.preprocessor == "enh":
retval = preprocessor_choices.get_class(args.preprocessor)(
train=train,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=getattr(args, "rir_scp", None),
rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
noise_scp=getattr(args, "noise_scp", None),
noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
noise_db_range=getattr(args, "noise_db_range", "13_15"),
short_noise_thres=getattr(args, "short_noise_thres", 0.5),
speech_volume_normalize=getattr(
args, "speech_volume_normalize", None
),
use_reverberant_ref=getattr(args, "use_reverberant_ref", None),
num_spk=getattr(args, "num_spk", 1),
num_noise_type=getattr(args, "num_noise_type", 1),
sample_rate=getattr(args, "sample_rate", 8000),
force_single_channel=getattr(args, "force_single_channel", False),
channel_reordering=getattr(args, "channel_reordering", False),
categories=getattr(args, "categories", None),
)
else:
raise ValueError(
f"Preprocessor type {args.preprocessor} is not supported."
)
else:
retval = None
assert check_return_type(retval)
return retval

    @classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
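            # Training / validation mode: at least one speech reference is required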
retval = ("speech_ref1",)
else:
# Inference mode
retval = ("speech_mix",)
return retval

    @classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ["speech_mix"]
retval += ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval += ["speech_ref{}".format(n) for n in range(2, MAX_REFERENCE_NUM + 1)]
retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval += ["category"]
retval = tuple(retval)
assert check_return_type(retval)
return retval

    @classmethod
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:
assert check_argument_types()
encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)
separator = separator_choices.get_class(args.separator)(
encoder.output_dim, **args.separator_conf
)
decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf)
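        # Separators registered with a "*_nomask" name (e.g. tcn_nomask) do not
        # estimate masks themselves, so a dedicated mask module (--mask_module)
        # is built for them.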
if args.separator.endswith("nomask"):
mask_module = mask_module_choices.get_class(args.mask_module)(
input_dim=encoder.output_dim,
**args.mask_module_conf,
)
else:
mask_module = None
loss_wrappers = []
if getattr(args, "criterions", None) is not None:
            # This check is for compatibility when loading models
            # packed by an older version.
for ctr in args.criterions:
criterion_conf = ctr.get("conf", {})
criterion = criterion_choices.get_class(ctr["name"])(**criterion_conf)
loss_wrapper = loss_wrapper_choices.get_class(ctr["wrapper"])(
criterion=criterion, **ctr["wrapper_conf"]
)
loss_wrappers.append(loss_wrapper)
# 1. Build model
model = ESPnetEnhancementModel(
encoder=encoder,
separator=separator,
decoder=decoder,
loss_wrappers=loss_wrappers,
mask_module=mask_module,
**args.model_conf,
)
# FIXME(kamo): Should be done in model?
# 2. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model

    @classmethod
def build_iter_factory(
cls,
args: argparse.Namespace,
distributed_option: DistributedOption,
mode: str,
kwargs: dict = None,
) -> AbsIterFactory:
dynamic_mixing = getattr(args, "dynamic_mixing", False)
if dynamic_mixing and mode == "train":
args = copy.deepcopy(args)
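            # Keep only the first fold_length entry, since the mixture and
            # references are generated on the fly by the dynamic-mixing
            # preprocessor.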
args.fold_length = args.fold_length[0:1]
return super().build_iter_factory(args, distributed_option, mode, kwargs)