# Copyright 2023 Yifeng Yu
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Avocodo Modules.
This code is modified from https://github.com/ncsoft/avocodo.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn.utils import spectral_norm, weight_norm
from espnet2.gan_svs.visinger2.visinger2_vocoder import MultiFrequencyDiscriminator
from espnet2.gan_tts.hifigan.residual_block import ResidualBlock
from espnet2.gan_tts.melgan.pqmf import PQMF


def get_padding(kernel_size, dilation=1):
    """Return "same" padding for a stride-1 dilated convolution (odd kernels)."""
    return int((kernel_size * dilation - dilation) / 2)
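

# For example (illustrative, not part of the original code): a stride-1
# Conv1d(kernel_size=7, dilation=3, padding=get_padding(7, 3)) gets
# padding (7 * 3 - 3) / 2 == 9 and leaves the time axis unchanged.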


class AvocodoGenerator(torch.nn.Module):
    """Avocodo generator module."""

def __init__(
self,
in_channels: int = 80,
out_channels: int = 1,
channels: int = 512,
global_channels: int = -1,
kernel_size: int = 7,
upsample_scales: List[int] = [8, 8, 2, 2],
upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
projection_filters: List[int] = [0, 1, 1, 1],
projection_kernels: List[int] = [0, 5, 7, 11],
use_additional_convs: bool = True,
bias: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.2},
use_weight_norm: bool = True,
):
"""Initialize AvocodoGenerator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
channels (int): Number of hidden representation channels.
global_channels (int): Number of global conditioning channels.
kernel_size (int): Kernel size of initial and final conv layer.
upsample_scales (List[int]): List of upsampling scales.
upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers.
resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
            resblock_dilations (List[List[int]]): List of list of dilations for
                residual blocks.
            projection_filters (List[int]): List of output channels for the
                projection conv of each stage (0 skips the projection).
            projection_kernels (List[int]): List of kernel sizes for the
                projection convs (0 for skipped stages).
            use_additional_convs (bool): Whether to use additional conv layers in
                residual blocks.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
"""
super().__init__()
# check hyperparameters are valid
        assert kernel_size % 2 == 1, "Kernel size must be an odd number."
assert len(upsample_scales) == len(upsample_kernel_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
# define modules
self.num_upsamples = len(upsample_kernel_sizes)
self.num_blocks = len(resblock_kernel_sizes)
self.input_conv = torch.nn.Conv1d(
in_channels,
channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
self.upsamples = torch.nn.ModuleList()
self.blocks = torch.nn.ModuleList()
self.output_conv = torch.nn.ModuleList()
for i in range(len(upsample_kernel_sizes)):
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
self.upsamples += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.ConvTranspose1d(
channels // (2**i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
output_padding=upsample_scales[i] % 2,
),
)
]
for j in range(len(resblock_kernel_sizes)):
self.blocks += [
ResidualBlock(
kernel_size=resblock_kernel_sizes[j],
channels=channels // (2 ** (i + 1)),
dilations=resblock_dilations[j],
bias=bias,
use_additional_convs=use_additional_convs,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
)
]
if projection_filters[i] != 0:
self.output_conv.append(
torch.nn.Conv1d(
channels // (2 ** (i + 1)),
projection_filters[i],
projection_kernels[i],
1,
padding=projection_kernels[i] // 2,
)
)
else:
self.output_conv.append(torch.nn.Identity())
if global_channels > 0:
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# reset parameters
self.reset_parameters()

    def forward(
        self, c: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> List[torch.Tensor]:
"""Calculate forward propagation.
Args:
c (Tensor): Input tensor (B, in_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
            List[Tensor]: List of multi-scale output tensors (B, out_channels, T_i),
                coarsest scale first.
"""
outs = []
c = self.input_conv(c)
if g is not None:
c = c + self.global_conv(g)
for i in range(self.num_upsamples):
c = self.upsamples[i](c)
cs = 0.0 # initialize
for j in range(self.num_blocks):
cs += self.blocks[i * self.num_blocks + j](c)
c = cs / self.num_blocks
if i >= (self.num_upsamples - 3):
_c = F.leaky_relu(c)
_c = self.output_conv[i](_c)
_c = torch.tanh(_c)
outs.append(_c)
else:
c = self.output_conv[i](c)
return outs

    def reset_parameters(self):
        """Reset parameters.

        This initialization follows the official implementation manner.
        https://github.com/jik876/hifi-gan/blob/master/models.py
        """

def _reset_parameters(m: torch.nn.Module):
if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
m.weight.data.normal_(0.0, 0.01)
logging.debug(f"Reset parameters in {m}.")
self.apply(_reset_parameters)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

def _remove_weight_norm(m: torch.nn.Module):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""

def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d) or isinstance(
m, torch.nn.ConvTranspose1d
):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
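

# Usage sketch (illustrative, not part of the original module): with the
# default settings, the generator emits one waveform per projection stage,
# coarsest scale first.
#
#     gen = AvocodoGenerator()
#     mel = torch.randn(2, 80, 32)      # (B, in_channels, T_frames)
#     outs = gen(mel)                   # list of 3 tensors
#     # outs[0]: (2, 1, 32 * 64)   after 8x8 upsampling (quarter rate)
#     # outs[1]: (2, 1, 32 * 128)  after 8x8x2 upsampling (half rate)
#     # outs[2]: (2, 1, 32 * 256)  full-rate waveform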


# CoMBD
class CoMBDBlock(torch.nn.Module):
    """CoMBD (Collaborative Multi-band Discriminator) block module."""

    def __init__(
        self,
        h_u: List[int],
        d_k: List[int],
        d_s: List[int],
        d_d: List[int],
        d_g: List[int],
        d_p: List[int],
        op_f: int,
        op_k: int,
        op_g: int,
        use_spectral_norm=False,
    ):
        """Initialize CoMBDBlock module.

        Args:
            h_u (List[int]): Hidden unit sizes for each layer.
            d_k (List[int]): Kernel sizes for each layer.
            d_s (List[int]): Strides for each layer.
            d_d (List[int]): Dilations for each layer.
            d_g (List[int]): Groups for each layer.
            d_p (List[int]): Paddings for each layer.
            op_f (int): Output channels of the projection conv.
            op_k (int): Kernel size of the projection conv.
            op_g (int): Groups of the projection conv.
            use_spectral_norm (bool): Use spectral norm instead of weight norm.
        """
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = torch.nn.ModuleList()
filters = [[1, h_u[0]]]
for i in range(len(h_u) - 1):
filters.append([h_u[i], h_u[i + 1]])
for _f, _k, _s, _d, _g, _p in zip(filters, d_k, d_s, d_d, d_g, d_p):
self.convs.append(
norm_f(
Conv1d(
in_channels=_f[0],
out_channels=_f[1],
kernel_size=_k,
stride=_s,
dilation=_d,
groups=_g,
padding=_p,
)
)
)
self.projection_conv = norm_f(
Conv1d(
in_channels=filters[-1][1],
out_channels=op_f,
kernel_size=op_k,
groups=op_g,
)
)

    def forward(self, x):
        """Forward pass through the CoMBD block.

        Args:
            x (Tensor): Input tensor of shape (B, C_in, T_in).

        Returns:
            Tuple[Tensor, List[Tensor]]: Output tensor of shape (B, C_out, T_out)
                and a list of feature maps of shape (B, C, T) from each Conv1d
                layer.
        """
fmap = []
for block in self.convs:
x = block(x)
x = F.leaky_relu(x, 0.2)
fmap.append(x)
x = self.projection_conv(x)
return x, fmap
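

# Channel bookkeeping sketch (assumed, using the default combd_h_u column
# [16, 64, 256, 1024, 1024, 1024] defined below): `filters` becomes
# [[1, 16], [16, 64], ..., [1024, 1024]], so each CoMBDBlock consumes a
# single-channel waveform, widens it layer by layer, and finally projects
# back to one channel.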


class CoMBD(torch.nn.Module):
    """CoMBD (Collaborative Multi-band Discriminator) module
    from https://arxiv.org/abs/2206.13404."""

    def __init__(self, h, pqmf_list=None, use_spectral_norm=False):
        super().__init__()
        self.h = h
        if pqmf_list is not None:
            self.pqmf = pqmf_list
        else:
            self.pqmf = [
                PQMF(*h["pqmf_config"]["lv2"]),
                PQMF(*h["pqmf_config"]["lv1"]),
            ]
self.blocks = torch.nn.ModuleList()
for _h_u, _d_k, _d_s, _d_d, _d_g, _d_p, _op_f, _op_k, _op_g in zip(
h["combd_h_u"],
h["combd_d_k"],
h["combd_d_s"],
h["combd_d_d"],
h["combd_d_g"],
h["combd_d_p"],
h["combd_op_f"],
h["combd_op_k"],
h["combd_op_g"],
):
self.blocks.append(
CoMBDBlock(
_h_u,
_d_k,
_d_s,
_d_d,
_d_g,
_d_p,
_op_f,
_op_k,
                    _op_g,
                    use_spectral_norm=use_spectral_norm,
                )
            )
    def _block_forward(self, inputs, blocks, outs, f_maps):
        for x, block in zip(inputs, blocks):
out, f_map = block(x)
outs.append(out)
f_maps.append(f_map)
return outs, f_maps
def _pqmf_forward(self, ys, ys_hat):
# preprocess for multi_scale forward
multi_scale_inputs = []
multi_scale_inputs_hat = []
for pqmf in self.pqmf:
multi_scale_inputs.append(pqmf.to(ys[-1]).analysis(ys[-1])[:, :1, :])
multi_scale_inputs_hat.append(
pqmf.to(ys[-1]).analysis(ys_hat[-1])[:, :1, :]
)
outs_real = []
f_maps_real = []
# real
# for hierarchical forward
outs_real, f_maps_real = self._block_forward(
ys, self.blocks, outs_real, f_maps_real
)
# for multi_scale forward
outs_real, f_maps_real = self._block_forward(
multi_scale_inputs, self.blocks[:-1], outs_real, f_maps_real
)
outs_fake = []
f_maps_fake = []
# predicted
# for hierarchical forward
outs_fake, f_maps_fake = self._block_forward(
ys_hat, self.blocks, outs_fake, f_maps_fake
)
# for multi_scale forward
outs_fake, f_maps_fake = self._block_forward(
multi_scale_inputs_hat, self.blocks[:-1], outs_fake, f_maps_fake
)
return outs_real, outs_fake, f_maps_real, f_maps_fake

    def forward(self, ys, ys_hat):
        """Calculate forward propagation.

        Args:
ys (List[Tensor]): List of ground truth signals of shape (B, 1, T).
ys_hat (List[Tensor]): List of predicted signals of shape (B, 1, T).
Returns:
Tuple[List[Tensor], List[Tensor], List[List[Tensor]], List[List[Tensor]]]:
Tuple containing the list of output tensors of shape (B, C_out, T_out)
for real and fake, respectively, and the list of feature maps of shape
(B, C, T) at each Conv1d layer for real and fake, respectively.
"""
outs_real, outs_fake, f_maps_real, f_maps_fake = self._pqmf_forward(ys, ys_hat)
return outs_real, outs_fake, f_maps_real, f_maps_fake


# SBD
class MDC(torch.nn.Module):
    """Multiscale Dilated Convolution from https://arxiv.org/pdf/1609.07093.pdf."""

def __init__(
self,
in_channels,
out_channels,
strides,
kernel_size,
dilations,
use_spectral_norm=False,
):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
self.d_convs = torch.nn.ModuleList()
for _k, _d in zip(kernel_size, dilations):
self.d_convs.append(
norm_f(
Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=_k,
dilation=_d,
padding=get_padding(_k, _d),
)
)
)
        # NOTE: the padding reuses the last (_k, _d) pair from the loop above,
        # following the upstream avocodo implementation.
        self.post_conv = norm_f(
            Conv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=strides,
                padding=get_padding(_k, _d),
            )
        )
self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Sum the parallel dilated conv branches, then apply the strided conv."""
        _out = None
        for _l in self.d_convs:
            _x = torch.unsqueeze(_l(x), -1)
            _x = F.leaky_relu(_x, 0.2)
            if _out is None:
                _out = _x
            else:
                _out = torch.cat([_out, _x], dim=-1)
        x = torch.sum(_out, dim=-1)
        x = self.post_conv(x)
        x = F.leaky_relu(x, 0.2)
        return x
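

# Note on MDC: the dilated convolutions run in parallel on the same input and
# their activations are summed, so every branch must preserve the time axis
# (ensured by get_padding above); the strided post_conv then downsamples.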


class SBDBlock(torch.nn.Module):
    """SBD (Sub-band Discriminator) block."""

def __init__(
self,
segment_dim,
strides,
filters,
kernel_size,
dilations,
use_spectral_norm=False,
):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = torch.nn.ModuleList()
filters_in_out = [(segment_dim, filters[0])]
for i in range(len(filters) - 1):
filters_in_out.append([filters[i], filters[i + 1]])
for _s, _f, _k, _d in zip(strides, filters_in_out, kernel_size, dilations):
self.convs.append(
MDC(
in_channels=_f[0],
out_channels=_f[1],
strides=_s,
kernel_size=_k,
dilations=_d,
use_spectral_norm=use_spectral_norm,
)
)
        self.post_conv = norm_f(
            Conv1d(
                in_channels=filters_in_out[-1][1],
                out_channels=1,
                kernel_size=3,
                stride=1,
                padding=3 // 2,
            )
        )

    def forward(self, x):
        """Return the block score and the intermediate feature maps."""
        fmap = []
        for _l in self.convs:
            x = _l(x)
            fmap.append(x)
        x = self.post_conv(x)
        return x, fmap


class MDCDConfig:
    """Configuration holder for the sub-band discriminator (SBD)."""

    def __init__(self, h):
self.pqmf_params = h["pqmf_config"]["sbd"]
self.f_pqmf_params = h["pqmf_config"]["fsbd"]
self.filters = h["sbd_filters"]
self.kernel_sizes = h["sbd_kernel_sizes"]
self.dilations = h["sbd_dilations"]
self.strides = h["sbd_strides"]
self.band_ranges = h["sbd_band_ranges"]
self.transpose = h["sbd_transpose"]
self.segment_size = h["segment_size"]


class SBD(torch.nn.Module):
    """SBD (Sub-band Discriminator) from https://arxiv.org/pdf/2206.13404.pdf."""

    def __init__(self, h, use_spectral_norm=False):
        super().__init__()
        self.config = MDCDConfig(h)
        self.pqmf = PQMF(*self.config.pqmf_params)
        if any(h["sbd_transpose"]):
            self.f_pqmf = PQMF(*self.config.f_pqmf_params)
        else:
            self.f_pqmf = None
self.discriminators = torch.nn.ModuleList()
for _f, _k, _d, _s, _br, _tr in zip(
self.config.filters,
self.config.kernel_sizes,
self.config.dilations,
self.config.strides,
self.config.band_ranges,
self.config.transpose,
):
if _tr:
segment_dim = self.config.segment_size // _br[1] - _br[0]
else:
segment_dim = _br[1] - _br[0]
self.discriminators.append(
SBDBlock(
segment_dim=segment_dim,
filters=_f,
kernel_size=_k,
dilations=_d,
strides=_s,
use_spectral_norm=use_spectral_norm,
)
)

    def forward(self, y, y_hat):
        """Return SBD outputs and feature maps for real and predicted signals."""
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
y_in = self.pqmf.analysis(y)
y_hat_in = self.pqmf.analysis(y_hat)
if self.f_pqmf is not None:
y_in_f = self.f_pqmf.analysis(y)
y_hat_in_f = self.f_pqmf.analysis(y_hat)
for d, br, tr in zip(
self.discriminators, self.config.band_ranges, self.config.transpose
):
if tr:
_y_in = y_in_f[:, br[0] : br[1], :]
_y_hat_in = y_hat_in_f[:, br[0] : br[1], :]
_y_in = torch.transpose(_y_in, 1, 2)
_y_hat_in = torch.transpose(_y_hat_in, 1, 2)
else:
_y_in = y_in[:, br[0] : br[1], :]
_y_hat_in = y_hat_in[:, br[0] : br[1], :]
y_d_r, fmap_r = d(_y_in)
y_d_g, fmap_g = d(_y_hat_in)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
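

# Band-splitting sketch (assumed, with the default sbd config defined below):
# self.pqmf.analysis(y) turns (B, 1, T) into (B, 16, T / 16); each SBDBlock
# then scores a slice of those sub-bands, e.g. band range [0, 6] selects
# channels 0..5. The transposed branch instead swaps the 64-band fsbd
# output's time and channel axes to capture long-range temporal structure.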


class AvocodoDiscriminator(torch.nn.Module):
    """Avocodo discriminator module."""

def __init__(
self,
combd: Dict[str, Any] = {
"combd_h_u": [
[16, 64, 256, 1024, 1024, 1024],
[16, 64, 256, 1024, 1024, 1024],
[16, 64, 256, 1024, 1024, 1024],
],
"combd_d_k": [
[7, 11, 11, 11, 11, 5],
[11, 21, 21, 21, 21, 5],
[15, 41, 41, 41, 41, 5],
],
"combd_d_s": [
[1, 1, 4, 4, 4, 1],
[1, 1, 4, 4, 4, 1],
[1, 1, 4, 4, 4, 1],
],
"combd_d_d": [
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
],
"combd_d_g": [
[1, 4, 16, 64, 256, 1],
[1, 4, 16, 64, 256, 1],
[1, 4, 16, 64, 256, 1],
],
"combd_d_p": [
[3, 5, 5, 5, 5, 2],
[5, 10, 10, 10, 10, 2],
[7, 20, 20, 20, 20, 2],
],
"combd_op_f": [1, 1, 1],
"combd_op_k": [3, 3, 3],
"combd_op_g": [1, 1, 1],
},
sbd: Dict[str, Any] = {
"use_sbd": True,
"sbd_filters": [
[64, 128, 256, 256, 256],
[64, 128, 256, 256, 256],
[64, 128, 256, 256, 256],
[32, 64, 128, 128, 128],
],
"sbd_strides": [
[1, 1, 3, 3, 1],
[1, 1, 3, 3, 1],
[1, 1, 3, 3, 1],
[1, 1, 3, 3, 1],
],
"sbd_kernel_sizes": [
[[7, 7, 7], [7, 7, 7], [7, 7, 7], [7, 7, 7], [7, 7, 7]],
[[5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5]],
[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
[[5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5]],
],
"sbd_dilations": [
[[5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11]],
[[3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7]],
[[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]],
[[1, 2, 3], [1, 2, 3], [1, 2, 3], [2, 3, 5], [2, 3, 5]],
],
"sbd_band_ranges": [[0, 6], [0, 11], [0, 16], [0, 64]],
"sbd_transpose": [False, False, False, True],
"pqmf_config": {
"sbd": [16, 256, 0.03, 10.0],
"fsbd": [64, 256, 0.1, 9.0],
},
"segment_size": 8192,
},
pqmf_config: Dict[str, Any] = {
"lv1": [2, 256, 0.25, 10.0],
"lv2": [4, 192, 0.13, 10.0],
},
projection_filters: List[int] = [0, 1, 1, 1],
):
        super().__init__()
        self.pqmf_lv2 = PQMF(*pqmf_config["lv2"])
        self.pqmf_lv1 = PQMF(*pqmf_config["lv1"])
        self.combd = CoMBD(
            combd,
            [self.pqmf_lv2, self.pqmf_lv1],
            # Default to weight norm when the config omits the key.
            use_spectral_norm=combd.get("use_spectral_norm", False),
        )
        self.sbd = SBD(
            sbd,
            use_spectral_norm=sbd.get("use_spectral_norm", False),
        )
self.projection_filters = projection_filters

    def forward(
        self, y: torch.Tensor, y_hats: List[torch.Tensor]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[List[torch.Tensor]],
        List[List[torch.Tensor]],
    ]:
        """Return CoMBD and SBD outputs and feature maps for real and fake."""
        ys = [
            self.pqmf_lv2.analysis(y)[:, : self.projection_filters[1]],
            self.pqmf_lv1.analysis(y)[:, : self.projection_filters[2]],
            y,
        ]
(
combd_outs_real,
combd_outs_fake,
combd_fmaps_real,
combd_fmaps_fake,
) = self.combd(ys, y_hats)
sbd_outs_real, sbd_outs_fake, sbd_fmaps_real, sbd_fmaps_fake = self.sbd(
y, y_hats[-1]
)
# Combine the outputs of both discriminators
outs_real = combd_outs_real + sbd_outs_real
outs_fake = combd_outs_fake + sbd_outs_fake
fmaps_real = combd_fmaps_real + sbd_fmaps_real
fmaps_fake = combd_fmaps_fake + sbd_fmaps_fake
return outs_real, outs_fake, fmaps_real, fmaps_fake
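

# Usage sketch (illustrative, not part of the original module; shapes assume
# the defaults above, which imply segment_size == 8192 samples):
#
#     gen = AvocodoGenerator()
#     disc = AvocodoDiscriminator()
#     mel = torch.randn(2, 80, 32)            # (B, in_channels, T_frames)
#     y = torch.randn(2, 1, 32 * 256)         # ground-truth waveform
#     y_hats = gen(mel)                       # 3 multi-scale waveforms
#     outs_real, outs_fake, fmaps_real, fmaps_fake = disc(y, y_hats)
#     # len(outs_real) == 9: 5 from CoMBD (3 hierarchical + 2 multi-scale)
#     # and 4 from SBD (one per band range).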


class AvocodoDiscriminatorPlus(torch.nn.Module):
    """Avocodo discriminator with an additional multi-frequency discriminator."""

def __init__(
self,
combd: Dict[str, Any] = {
"combd_h_u": [
[16, 64, 256, 1024, 1024, 1024],
[16, 64, 256, 1024, 1024, 1024],
[16, 64, 256, 1024, 1024, 1024],
],
"combd_d_k": [
[7, 11, 11, 11, 11, 5],
[11, 21, 21, 21, 21, 5],
[15, 41, 41, 41, 41, 5],
],
"combd_d_s": [
[1, 1, 4, 4, 4, 1],
[1, 1, 4, 4, 4, 1],
[1, 1, 4, 4, 4, 1],
],
"combd_d_d": [
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
],
"combd_d_g": [
[1, 4, 16, 64, 256, 1],
[1, 4, 16, 64, 256, 1],
[1, 4, 16, 64, 256, 1],
],
"combd_d_p": [
[3, 5, 5, 5, 5, 2],
[5, 10, 10, 10, 10, 2],
[7, 20, 20, 20, 20, 2],
],
"combd_op_f": [1, 1, 1],
"combd_op_k": [3, 3, 3],
"combd_op_g": [1, 1, 1],
},
sbd: Dict[str, Any] = {
"use_sbd": True,
"sbd_filters": [
[64, 128, 256, 256, 256],
[64, 128, 256, 256, 256],
[64, 128, 256, 256, 256],
[32, 64, 128, 128, 128],
],
"sbd_strides": [
[1, 1, 3, 3, 1],
[1, 1, 3, 3, 1],
[1, 1, 3, 3, 1],
[1, 1, 3, 3, 1],
],
"sbd_kernel_sizes": [
[[7, 7, 7], [7, 7, 7], [7, 7, 7], [7, 7, 7], [7, 7, 7]],
[[5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5]],
[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
[[5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5]],
],
"sbd_dilations": [
[[5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11]],
[[3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7]],
[[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]],
[[1, 2, 3], [1, 2, 3], [1, 2, 3], [2, 3, 5], [2, 3, 5]],
],
"sbd_band_ranges": [[0, 6], [0, 11], [0, 16], [0, 64]],
"sbd_transpose": [False, False, False, True],
"pqmf_config": {
"sbd": [16, 256, 0.03, 10.0],
"fsbd": [64, 256, 0.1, 9.0],
},
"segment_size": 8192,
},
pqmf_config: Dict[str, Any] = {
"lv1": [2, 256, 0.25, 10.0],
"lv2": [4, 192, 0.13, 10.0],
},
projection_filters: List[int] = [0, 1, 1, 1],
# Multi-frequency discriminator related
sample_rate: int = 22050,
multi_freq_disc_params: Dict[str, Any] = {
"hop_length_factors": [4, 8, 16],
"hidden_channels": [256, 512, 512],
"domain": "double",
"mel_scale": True,
"divisors": [32, 16, 8, 4, 2, 1, 1],
"strides": [1, 2, 1, 2, 1, 2, 1],
},
):
        super().__init__()
        self.pqmf_lv2 = PQMF(*pqmf_config["lv2"])
        self.pqmf_lv1 = PQMF(*pqmf_config["lv1"])
        self.combd = CoMBD(
            combd,
            [self.pqmf_lv2, self.pqmf_lv1],
            # Default to weight norm when the config omits the key.
            use_spectral_norm=combd.get("use_spectral_norm", False),
        )
        self.sbd = SBD(
            sbd,
            use_spectral_norm=sbd.get("use_spectral_norm", False),
        )
# Multi-frequency discriminator related
if "hop_lengths" not in multi_freq_disc_params:
            # Convert hop length factors (milliseconds) into hop lengths
            # (samples) at the given sampling rate.
multi_freq_disc_params["hop_lengths"] = []
for i in range(len(multi_freq_disc_params["hop_length_factors"])):
multi_freq_disc_params["hop_lengths"].append(
int(
sample_rate
* multi_freq_disc_params["hop_length_factors"][i]
/ 1000
)
)
del multi_freq_disc_params["hop_length_factors"]
self.mfd = MultiFrequencyDiscriminator(
**multi_freq_disc_params,
)
self.projection_filters = projection_filters

    def forward(
        self, y: torch.Tensor, y_hats: List[torch.Tensor]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[List[torch.Tensor]],
        List[List[torch.Tensor]],
    ]:
        """Return CoMBD, SBD, and MFD outputs and feature maps for real and fake."""
        ys = [
            self.pqmf_lv2.analysis(y)[:, : self.projection_filters[1]],
            self.pqmf_lv1.analysis(y)[:, : self.projection_filters[2]],
            y,
        ]
(
combd_outs_real,
combd_outs_fake,
combd_fmaps_real,
combd_fmaps_fake,
) = self.combd(ys, y_hats)
sbd_outs_real, sbd_outs_fake, sbd_fmaps_real, sbd_fmaps_fake = self.sbd(
y, y_hats[-1]
)
mfd_fmaps_real = self.mfd(y)
mfd_fmaps_fake = self.mfd(y_hats[-1])
mfd_outs_real = mfd_fmaps_real[-1]
mfd_outs_fake = mfd_fmaps_fake[-1]
        # Combine the outputs of all three discriminators
outs_real = combd_outs_real + sbd_outs_real + mfd_outs_real
outs_fake = combd_outs_fake + sbd_outs_fake + mfd_outs_fake
fmaps_real = combd_fmaps_real + sbd_fmaps_real + mfd_fmaps_real
fmaps_fake = combd_fmaps_fake + sbd_fmaps_fake + mfd_fmaps_fake
return outs_real, outs_fake, fmaps_real, fmaps_fake
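

# Hop-length conversion example (illustrative, not part of the original
# module): with the default sample_rate=22050 and
# hop_length_factors=[4, 8, 16] (milliseconds), the derived MFD hop lengths
# are [88, 176, 352] samples, since int(22050 * 4 / 1000) == 88, and so on.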