from collections import OrderedDict
from distutils.version import LooseVersion
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch_complex.tensor import ComplexTensor
from espnet2.enh.layers.complex_utils import is_complex
from espnet2.enh.layers.dptnet import DPTNet
from espnet2.enh.layers.tcn import choose_norm
from espnet2.enh.separator.abs_separator import AbsSeparator
# True when the installed torch version is at least 1.9.0.
# Not referenced anywhere in the visible part of this module.
# NOTE(review): ``distutils`` is deprecated (PEP 632) and no longer ships with
# Python 3.12+ -- confirm the supported interpreter range before relying on it.
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
class DPTNetSeparator(AbsSeparator):
    """Mask-estimating separator built around a dual-path network (DPTNet).

    The encoded feature is normalized, cut into overlapping chunks, processed
    by ``DPTNet``, merged back to the original length and turned into one
    multiplicative mask per output (speakers, plus noise if requested).
    """

    def __init__(
        self,
        input_dim: int,
        post_enc_relu: bool = True,
        rnn_type: str = "lstm",
        bidirectional: bool = True,
        num_spk: int = 2,
        predict_noise: bool = False,
        unit: int = 256,
        att_heads: int = 4,
        dropout: float = 0.0,
        activation: str = "relu",
        norm_type: str = "gLN",
        layer: int = 6,
        segment_size: int = 20,
        nonlinear: str = "relu",
    ):
        """Dual-Path Transformer Network (DPTNet) Separator

        Args:
            input_dim: input feature dimension
            post_enc_relu: whether to apply ReLU to a real-valued input before
                normalization. Ignored for complex input, whose magnitude is
                used instead.
            rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
            bidirectional: bool, whether the inter-chunk RNN layers are bidirectional.
            num_spk: number of speakers
            predict_noise: whether to output the estimated noise signal
            unit: int, dimension of the hidden state.
            att_heads: number of attention heads.
            dropout: float, dropout ratio. Default is 0.
            activation: activation function applied at the output of RNN.
            norm_type: type of normalization to use after each inter- or
                intra-chunk Transformer block. Also used for the normalization
                applied to the input feature.
            nonlinear: the nonlinear function for mask estimation,
                select from 'relu', 'tanh', 'sigmoid'
            layer: int, number of layers, forwarded to DPTNet as ``num_layers``.
                Default is 6.
            segment_size: dual-path segment size. Must be at least 2 because
                chunks are taken with a hop of ``segment_size // 2``.

        Raises:
            ValueError: if ``nonlinear`` is not one of the supported names, or
                if ``segment_size`` is smaller than 2.
        """
        super().__init__()

        # Validate the cheap arguments first so that a bad configuration fails
        # before any submodule is allocated.
        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))
        if segment_size < 2:
            # segment_size // 2 is used as the unfold/fold stride; a value of 0
            # would only surface as an error during the first forward pass.
            raise ValueError(
                "segment_size must be >= 2, but got {}".format(segment_size)
            )

        self._num_spk = num_spk
        self.predict_noise = predict_noise
        self.segment_size = segment_size

        self.post_enc_relu = post_enc_relu
        self.enc_LN = choose_norm(norm_type, input_dim)

        # One extra output stream is reserved for the noise estimate.
        self.num_outputs = self.num_spk + 1 if self.predict_noise else self.num_spk
        self.dptnet = DPTNet(
            rnn_type=rnn_type,
            input_size=input_dim,
            hidden_size=unit,
            output_size=input_dim * self.num_outputs,
            att_heads=att_heads,
            dropout=dropout,
            activation=activation,
            num_layers=layer,
            bidirectional=bidirectional,
            norm_type=norm_type,
        )

        # gated output layer
        self.output = torch.nn.Sequential(
            torch.nn.Conv1d(input_dim, input_dim, 1), torch.nn.Tanh()
        )
        self.output_gate = torch.nn.Sequential(
            torch.nn.Conv1d(input_dim, input_dim, 1), torch.nn.Sigmoid()
        )

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]
            additional (Dict or None): other data included in model
                NOTE: not used in this model

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
            ilens (torch.Tensor): (B,)
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
            ]
            When ``predict_noise`` is enabled, ``others`` additionally holds
            the masked noise estimate under the key 'noise1'.
        """
        # if complex spectrum, the masks are estimated from its magnitude
        if is_complex(input):
            feature = abs(input)
        elif self.post_enc_relu:
            feature = torch.nn.functional.relu(input)
        else:
            feature = input

        B, T, N = feature.shape

        feature = feature.transpose(1, 2)  # B, N, T
        feature = self.enc_LN(feature)
        segmented = self.split_feature(feature)  # B, N, L, K

        processed = self.dptnet(segmented)  # B, N*num_outputs, L, K
        processed = processed.reshape(
            B * self.num_outputs, -1, processed.size(-2), processed.size(-1)
        )  # B*num_outputs, N, L, K

        processed = self.merge_feature(processed, length=T)  # B*num_outputs, N, T

        # gated output layer for filter generation (B*num_outputs, N, T)
        processed = self.output(processed) * self.output_gate(processed)

        masks = processed.reshape(B, self.num_outputs, N, T)

        # list[(B, T, N)]
        masks = self.nonlinear(masks.transpose(-1, -2)).unbind(dim=1)

        if self.predict_noise:
            # The last output stream is the noise mask; the rest are speakers.
            *masks, mask_noise = masks

        masked = [input * m for m in masks]

        others = OrderedDict(
            ("mask_spk{}".format(idx), mask) for idx, mask in enumerate(masks, 1)
        )
        if self.predict_noise:
            others["noise1"] = input * mask_noise

        return masked, ilens, others

    def split_feature(self, x):
        """Cut a sequence into half-overlapping chunks.

        Args:
            x (torch.Tensor): feature of shape (B, N, T)

        Returns:
            torch.Tensor: chunks of shape (B, N, segment_size, n_chunks).
            The time axis is zero-padded by ``segment_size`` on both sides
            before chunking, and consecutive chunks are ``segment_size // 2``
            frames apart.
        """
        B, N, T = x.size()
        # Treat time as the "height" of a (T, 1) image so that unfold extracts
        # sliding windows along the time axis only.
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.segment_size, 1),
            padding=(self.segment_size, 0),
            stride=(self.segment_size // 2, 1),
        )
        return unfolded.reshape(B, N, self.segment_size, -1)

    def merge_feature(self, x, length=None):
        """Overlap-add chunks back into a sequence, averaging the overlaps.

        Args:
            x (torch.Tensor): chunks of shape (B, N, L, n_chunks)
            length (int or None): length of the output sequence. When given,
                the chunks are assumed to come from ``split_feature`` and its
                padding of ``L`` frames per side is removed. When None, no
                padding is assumed and the length is derived from the number
                of chunks.

        Returns:
            torch.Tensor: merged sequence of shape (B, N, length)
        """
        B, N, L, n_chunks = x.size()
        hop_size = self.segment_size // 2
        if length is None:
            length = (n_chunks - 1) * hop_size + L
            padding = 0
        else:
            padding = (0, L)

        seq = x.reshape(B, N * L, n_chunks)
        x = torch.nn.functional.fold(
            seq,
            output_size=(1, length),
            kernel_size=(1, L),
            padding=padding,
            stride=(1, hop_size),
        )
        # Folding an all-ones tensor counts how many chunks contribute to each
        # output frame; dividing by it turns the overlap-add sum into a mean.
        norm_mat = torch.nn.functional.fold(
            input=torch.ones_like(seq),
            output_size=(1, length),
            kernel_size=(1, L),
            padding=padding,
            stride=(1, hop_size),
        )

        x /= norm_mat

        return x.reshape(B, N, length)

    @property
    def num_spk(self):
        """Number of speakers this separator was configured for."""
        return self._num_spk