from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor
from espnet2.enh.layers.complexnn import (
ComplexBatchNorm,
ComplexConv2d,
ComplexConvTranspose2d,
NavieComplexLSTM,
complex_cat,
)
from espnet2.enh.separator.abs_separator import AbsSeparator
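# Native torch.complex outputs are produced only on torch>=1.9; older versions
# fall back to ComplexTensor (see the use_builtin_complex option below).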
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
EPS = torch.finfo(torch.double).eps
class DCCRNSeparator(AbsSeparator):
def __init__(
self,
input_dim: int,
num_spk: int = 1,
rnn_layer: int = 2,
rnn_units: int = 256,
masking_mode: str = "E",
use_clstm: bool = True,
bidirectional: bool = False,
use_cbn: bool = False,
kernel_size: int = 5,
kernel_num: List[int] = [32, 64, 128, 256, 256, 256],
use_builtin_complex: bool = True,
use_noise_mask: bool = False,
):
"""DCCRN separator.
Args:
input_dim (int): input dimension。
num_spk (int, optional): number of speakers. Defaults to 1.
rnn_layer (int, optional): number of lstm layers in the crn. Defaults to 2.
rnn_units (int, optional): rnn units. Defaults to 128.
masking_mode (str, optional): usage of the estimated mask. Defaults to "E".
use_clstm (bool, optional): whether use complex LSTM. Defaults to False.
bidirectional (bool, optional): whether use BLSTM. Defaults to False.
use_cbn (bool, optional): whether use complex BN. Defaults to False.
kernel_size (int, optional): convolution kernel size. Defaults to 5.
kernel_num (list, optional): output dimension of each layer of the encoder.
use_builtin_complex (bool, optional): torch.complex if True,
else ComplexTensor.
use_noise_mask (bool, optional): whether to estimate the mask of noise.
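
        Examples:
            A minimal usage sketch (shapes only). The choice of
            ``input_dim=257`` assumes an STFT with ``n_fft=512`` and is
            illustrative, not a value fixed by this class::

                separator = DCCRNSeparator(input_dim=257, num_spk=1)
                spec = ComplexTensor(
                    torch.randn(4, 100, 257), torch.randn(4, 100, 257)
                )
                masked, ilens, others = separator(spec, torch.full((4,), 100))
                # masked[0] and others["mask_spk1"] both have shape (4, 100, 257)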
"""
super().__init__()
self.use_builtin_complex = use_builtin_complex
self._num_spk = num_spk
self.use_noise_mask = use_noise_mask
self.predict_noise = use_noise_mask
if masking_mode not in ["C", "E", "R"]:
raise ValueError("Unsupported masking mode: %s" % masking_mode)
# Network config
self.rnn_units = rnn_units
self.hidden_layers = rnn_layer
self.kernel_size = kernel_size
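        # the input spectrum is fed as 2 channels (real and imaginary parts)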
self.kernel_num = [2] + kernel_num
self.masking_mode = masking_mode
self.use_clstm = use_clstm
fac = 2 if bidirectional else 1
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
for idx in range(len(self.kernel_num) - 1):
self.encoder.append(
nn.Sequential(
ComplexConv2d(
self.kernel_num[idx],
self.kernel_num[idx + 1],
kernel_size=(self.kernel_size, 2),
stride=(2, 1),
padding=(2, 1),
),
nn.BatchNorm2d(self.kernel_num[idx + 1])
if not use_cbn
else ComplexBatchNorm(self.kernel_num[idx + 1]),
nn.PReLU(),
)
)
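        # number of frequency bins at the encoder output: each stride-2
        # encoder layer roughly halves the frequency axis, so this equals
        # ceil((input_dim - 1) / 2 ** (len(self.kernel_num) - 1))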
hidden_dim = (input_dim - 1 + 2 ** (len(self.kernel_num) - 1) - 1) // (
2 ** (len(self.kernel_num) - 1)
)
hidden_dim = hidden_dim if hidden_dim > 0 else 1
if self.use_clstm:
rnns = []
for idx in range(rnn_layer):
rnns.append(
NavieComplexLSTM(
input_size=hidden_dim * self.kernel_num[-1]
if idx == 0
else self.rnn_units * fac,
hidden_size=self.rnn_units,
bidirectional=bidirectional,
batch_first=False,
projection_dim=hidden_dim * self.kernel_num[-1]
if idx == rnn_layer - 1
else None,
)
)
self.enhance = nn.Sequential(*rnns)
else:
self.enhance = nn.LSTM(
input_size=hidden_dim * self.kernel_num[-1],
hidden_size=self.rnn_units,
                num_layers=rnn_layer,
dropout=0.0,
bidirectional=bidirectional,
batch_first=False,
)
self.tranform = nn.Linear(
self.rnn_units * fac, hidden_dim * self.kernel_num[-1]
)
for idx in range(len(self.kernel_num) - 1, 0, -1):
if idx != 1:
self.decoder.append(
nn.Sequential(
ComplexConvTranspose2d(
self.kernel_num[idx] * 2,
self.kernel_num[idx - 1],
kernel_size=(self.kernel_size, 2),
stride=(2, 1),
padding=(2, 0),
output_padding=(1, 0),
),
nn.BatchNorm2d(self.kernel_num[idx - 1])
if not use_cbn
else ComplexBatchNorm(self.kernel_num[idx - 1]),
nn.PReLU(),
)
)
else:
self.decoder.append(
nn.Sequential(
ComplexConvTranspose2d(
self.kernel_num[idx] * 2,
self.kernel_num[idx - 1] * (self._num_spk + 1)
if self.use_noise_mask
else self.kernel_num[idx - 1] * self._num_spk,
kernel_size=(self.kernel_size, 2),
stride=(2, 1),
padding=(2, 0),
output_padding=(1, 0),
),
)
)
self.flatten_parameters()
    def forward(
self,
input: Union[torch.Tensor, ComplexTensor],
ilens: torch.Tensor,
additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
"""Forward.
Args:
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]
ilens (torch.Tensor): (B,)
others predicted data, e.g. masks: OrderedDict[
'mask_spk1': torch.Tensor(Batch, Frames, Freq),
'mask_spk2': torch.Tensor(Batch, Frames, Freq),
...
'mask_spkn': torch.Tensor(Batch, Frames, Freq),
]
"""
# shape (B, T, F) --> (B, F, T)
specs = input.permute(0, 2, 1)
real, imag = specs.real, specs.imag
# shape (B, 2, F, T)
cspecs = torch.stack([real, imag], 1)
        # drop the zero-frequency (DC) bin; shape (B, 2, F-1, T)
cspecs = cspecs[:, :, 1:]
out = cspecs
encoder_out = []
        for layer in self.encoder:
out = layer(out)
encoder_out.append(out)
# shape (B, C, F, T)
batch_size, channels, dims, lengths = out.size()
# shape (T, B, C, F)
out = out.permute(3, 0, 1, 2)
if self.use_clstm:
# shape (T, B, C // 2, F)
r_rnn_in = out[:, :, : channels // 2]
# shape (T, B, C // 2, F)
i_rnn_in = out[:, :, channels // 2 :]
# shape (T, B, C // 2 * F)
r_rnn_in = torch.reshape(
r_rnn_in, [lengths, batch_size, channels // 2 * dims]
)
# shape (T, B, C // 2 * F)
i_rnn_in = torch.reshape(
i_rnn_in, [lengths, batch_size, channels // 2 * dims]
)
r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in])
# shape (T, B, C // 2, F)
r_rnn_in = torch.reshape(
r_rnn_in, [lengths, batch_size, channels // 2, dims]
)
# shape (T, B, C // 2, F)
i_rnn_in = torch.reshape(
i_rnn_in, [lengths, batch_size, channels // 2, dims]
)
# shape (T, B, C, F)
out = torch.cat([r_rnn_in, i_rnn_in], 2)
else:
# shape (T, B, C*F)
out = torch.reshape(out, [lengths, batch_size, channels * dims])
out, _ = self.enhance(out)
out = self.tranform(out)
# shape (T, B, C, F)
out = torch.reshape(out, [lengths, batch_size, channels, dims])
# shape (B, C, F, T)
out = out.permute(1, 2, 3, 0)
for idx in range(len(self.decoder)):
# skip connection
out = complex_cat([out, encoder_out[-1 - idx]], 1)
out = self.decoder[idx](out)
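            # drop the extra time frame introduced by the transposed
            # convolution (restores the input length)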
out = out[..., 1:]
# out shape = (B, 2*num_spk, F-1, T) if self.use_noise_mask == False
# else (B, 2*(num_spk+1), F-1, T)
masks = self.create_masks(out)
masked = self.apply_masks(masks, real, imag)
others = OrderedDict(
zip(
["mask_spk{}".format(i + 1) for i in range(self.num_spk)],
masks,
)
)
if self.use_noise_mask:
others["mask_noise1"] = masks[-1]
others["noise1"] = masked.pop(-1)
return (masked, ilens, others)
    def flatten_parameters(self):
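        """Flatten the parameters of the built-in LSTM for efficient execution."""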
if isinstance(self.enhance, nn.LSTM):
self.enhance.flatten_parameters()
    def create_masks(self, mask_tensor: torch.Tensor):
        """Create the estimated mask for each speaker.

        Args:
            mask_tensor (torch.Tensor): output of decoder,
                shape (B, 2*num_spk, F-1, T), or (B, 2*(num_spk+1), F-1, T)
                when self.use_noise_mask is True

        Returns:
            masks (List[Union[torch.Tensor, ComplexTensor]]): [(B, T, F), ...]
        """
if self.use_noise_mask:
assert mask_tensor.shape[1] == 2 * (self._num_spk + 1), mask_tensor.shape[1]
else:
assert mask_tensor.shape[1] == 2 * self._num_spk, mask_tensor.shape[1]
masks = []
for idx in range(mask_tensor.shape[1] // 2):
# shape (B, F-1, T)
mask_real = mask_tensor[:, idx * 2]
# shape (B, F-1, T)
mask_imag = mask_tensor[:, idx * 2 + 1]
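            # zero-pad the masks along frequency to restore the DC bin that
            # was dropped from the input spectrum in forward()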
# shape (B, F, T)
mask_real = F.pad(mask_real, [0, 0, 1, 0])
# shape (B, F, T)
mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
# mask shape (B, T, F)
if is_torch_1_9_plus and self.use_builtin_complex:
complex_mask = torch.complex(
mask_real.permute(0, 2, 1), mask_imag.permute(0, 2, 1)
)
else:
complex_mask = ComplexTensor(
mask_real.permute(0, 2, 1), mask_imag.permute(0, 2, 1)
)
masks.append(complex_mask)
return masks
    def apply_masks(
self,
masks: List[Union[torch.Tensor, ComplexTensor]],
real: torch.Tensor,
imag: torch.Tensor,
):
"""apply masks
Args:
masks : est_masks, [(B, T, F), ...]
real (torch.Tensor): real part of the noisy spectrum, (B, F, T)
imag (torch.Tensor): imag part of the noisy spectrum, (B, F, T)
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]
"""
masked = []
        for mask in masks:
            # shape (B, T, F) --> (B, F, T)
            mask_real = mask.real.permute(0, 2, 1)
            mask_imag = mask.imag.permute(0, 2, 1)
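            # Masking modes (following the DCCRN paper); every mask is applied
            # to the original noisy spectrum (real, imag):
            #   "E": tanh-bounded magnitude mask with additive phase correction
            #   "C": complex ratio masking (complex multiplication)
            #   "R": independent real-valued masks on real and imaginary parts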
            if self.masking_mode == "E":
                # magnitude and phase of the noisy spectrum, shape (B, F, T)
                spec_mags = torch.sqrt(real**2 + imag**2 + 1e-8)
                spec_phase = torch.atan2(imag, real)
                mask_mags = (mask_real**2 + mask_imag**2) ** 0.5
                real_phase = mask_real / (mask_mags + EPS)
                imag_phase = mask_imag / (mask_mags + EPS)
                mask_phase = torch.atan2(imag_phase, real_phase)
                mask_mags = torch.tanh(mask_mags)
                est_mags = mask_mags * spec_mags
                est_phase = spec_phase + mask_phase
                est_real = est_mags * torch.cos(est_phase)
                est_imag = est_mags * torch.sin(est_phase)
            elif self.masking_mode == "C":
                est_real, est_imag = (
                    real * mask_real - imag * mask_imag,
                    real * mask_imag + imag * mask_real,
                )
            elif self.masking_mode == "R":
                est_real, est_imag = real * mask_real, imag * mask_imag
# shape (B, F, T) --> (B, T, F)
            if is_torch_1_9_plus and self.use_builtin_complex:
                masked.append(
                    torch.complex(
                        est_real.permute(0, 2, 1), est_imag.permute(0, 2, 1)
                    )
                )
            else:
                masked.append(
                    ComplexTensor(
                        est_real.permute(0, 2, 1), est_imag.permute(0, 2, 1)
                    )
                )
return masked
@property
def num_spk(self):
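        """Number of speakers."""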
return self._num_spk