import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor
from espnet2.torch_utils.get_layer_from_string import get_layer
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
[docs]class Conv2DActNorm(torch.nn.Module):
"""Basic Conv2D + activation + instance norm building block."""
def __init__(
self,
in_channels,
out_channels,
ksz=(3, 3),
stride=(1, 2),
padding=(1, 0),
upsample=False,
activation=torch.nn.ELU,
):
super(Conv2DActNorm, self).__init__()
if upsample:
conv = torch.nn.ConvTranspose2d(
in_channels, out_channels, ksz, stride, padding
)
else:
conv = torch.nn.Conv2d(
in_channels, out_channels, ksz, stride, padding, padding_mode="reflect"
)
act = get_layer(activation)()
norm = torch.nn.GroupNorm(out_channels, out_channels, eps=1e-8)
self.layer = torch.nn.Sequential(conv, act, norm)
[docs] def forward(self, inp):
return self.layer(inp)
[docs]class FreqWiseBlock(torch.nn.Module):
"""FreqWiseBlock, see iNeuBe paper.
Block that applies pointwise 2D convolution over
STFT-like image tensor on frequency axis.
The input is assumed to be [batch, image_channels, frames, freq].
"""
def __init__(self, in_channels, num_freqs, out_channels, activation=torch.nn.ELU):
super(FreqWiseBlock, self).__init__()
self.bottleneck = Conv2DActNorm(
in_channels, out_channels, (1, 1), (1, 1), (0, 0), activation=activation
)
self.freq_proc = Conv2DActNorm(
num_freqs, num_freqs, (1, 1), (1, 1), (0, 0), activation=activation
)
[docs] def forward(self, inp):
# bsz, chans, x, y
out = self.freq_proc(self.bottleneck(inp).permute(0, 3, 2, 1)).permute(
0, 3, 2, 1
)
return out
[docs]class DenseBlock(torch.nn.Module):
"""single DenseNet block as used in iNeuBe model.
Args:
in_channels: number of input channels (image axis).
out_channels: number of output channels (image axis).
num_freqs: number of complex frequencies in the
input STFT complex image-like tensor.
The input is batch, image_channels, frames, freqs.
pre_blocks: dense block before point-wise convolution block over frequency axis.
freq_proc_blocks: number of frequency axis processing blocks.
post_blocks: dense block after point-wise convolution block over frequency axis.
ksz: kernel size used in densenet Conv2D layers.
activation: activation function to use in the whole iNeuBe model,
you can use any torch supported activation e.g. 'relu' or 'elu'.
hid_chans: number of hidden channels in densenet Conv2D.
"""
def __init__(
self,
in_channels,
out_channels,
num_freqs,
pre_blocks=2,
freq_proc_blocks=1,
post_blocks=2,
ksz=(3, 3),
activation=torch.nn.ELU,
hid_chans=32,
):
super(DenseBlock, self).__init__()
assert post_blocks >= 1
assert pre_blocks >= 1
self.pre_blocks = torch.nn.ModuleList([])
tot_layers = 0
for indx in range(pre_blocks):
c_layer = Conv2DActNorm(
in_channels + hid_chans * tot_layers,
hid_chans,
ksz,
(1, 1),
(1, 1),
activation=activation,
)
self.pre_blocks.append(c_layer)
tot_layers += 1
self.freq_proc_blocks = torch.nn.ModuleList([])
for indx in range(freq_proc_blocks):
c_layer = FreqWiseBlock(
in_channels + hid_chans * tot_layers,
num_freqs,
hid_chans,
activation=activation,
)
self.freq_proc_blocks.append(c_layer)
tot_layers += 1
self.post_blocks = torch.nn.ModuleList([])
for indx in range(post_blocks - 1):
c_layer = Conv2DActNorm(
in_channels + hid_chans * tot_layers,
hid_chans,
ksz,
(1, 1),
(1, 1),
activation=activation,
)
self.post_blocks.append(c_layer)
tot_layers += 1
last = Conv2DActNorm(
in_channels + hid_chans * tot_layers,
out_channels,
ksz,
(1, 1),
(1, 1),
activation=activation,
)
self.post_blocks.append(last)
[docs] def forward(self, input):
# batch, channels, frames, freq
out = [input]
for pre_block in self.pre_blocks:
c_out = pre_block(torch.cat(out, 1))
out.append(c_out)
for freq_block in self.freq_proc_blocks:
c_out = freq_block(torch.cat(out, 1))
out.append(c_out)
for post_block in self.post_blocks:
c_out = post_block(torch.cat(out, 1))
out.append(c_out)
return c_out
[docs]class TCNResBlock(torch.nn.Module):
"""single depth-wise separable TCN block as used in iNeuBe TCN.
Args:
in_chan: number of input feature channels.
out_chan: number of output feature channels.
ksz: kernel size.
stride: stride in depth-wise convolution.
dilation: dilation in depth-wise convolution.
activation: activation function to use in the whole iNeuBe model,
you can use any torch supported activation e.g. 'relu' or 'elu'.
"""
def __init__(
self, in_chan, out_chan, ksz=3, stride=1, dilation=1, activation=torch.nn.ELU
):
super(TCNResBlock, self).__init__()
padding = dilation
dconv = torch.nn.Conv1d(
in_chan,
in_chan,
ksz,
stride,
padding=padding,
dilation=dilation,
padding_mode="reflect",
groups=in_chan,
)
point_conv = torch.nn.Conv1d(in_chan, out_chan, 1)
self.layer = torch.nn.Sequential(
torch.nn.GroupNorm(in_chan, in_chan, eps=1e-8),
get_layer(activation)(),
dconv,
point_conv,
)
[docs] def forward(self, inp):
# [B, C, F] batch, channels, frames
return self.layer(inp) + inp
[docs]class TCNDenseUNet(torch.nn.Module):
"""TCNDenseNet block from iNeuBe
Reference:
Lu, Y. J., Cornell, S., Chang, X., Zhang, W., Li, C., Ni, Z., ... & Watanabe, S.
Towards Low-Distortion Multi-Channel Speech Enhancement:
The ESPNET-Se Submission to the L3DAS22 Challenge. ICASSP 2022 p. 9201-9205.
Args:
n_spk: number of output sources/speakers.
in_freqs: number of complex STFT frequencies.
mic_channels: number of microphones channels
(only fixed-array geometry supported).
hid_chans: number of channels in the subsampling/upsampling conv layers.
hid_chans_dense: number of channels in the densenet layers
(reduce this to reduce VRAM requirements).
ksz_dense: kernel size in the densenet layers thorough iNeuBe.
ksz_tcn: kernel size in the TCN submodule.
tcn_repeats: number of repetitions of blocks in the TCN submodule.
tcn_blocks: number of blocks in the TCN submodule.
tcn_channels: number of channels in the TCN submodule.
activation: activation function to use in the whole iNeuBe model,
you can use any torch supported activation e.g. 'relu' or 'elu'.
"""
def __init__(
self,
n_spk=1,
in_freqs=257,
mic_channels=1,
hid_chans=32,
hid_chans_dense=32,
ksz_dense=(3, 3),
ksz_tcn=3,
tcn_repeats=4,
tcn_blocks=7,
tcn_channels=384,
activation=torch.nn.ELU,
):
super(TCNDenseUNet, self).__init__()
self.n_spk = n_spk
self.in_channels = in_freqs
self.mic_channels = mic_channels
num_freqs = in_freqs - 2
first = torch.nn.Sequential(
torch.nn.Conv2d(
self.mic_channels * 2,
hid_chans,
(3, 3),
(1, 1),
(1, 0),
padding_mode="reflect",
),
DenseBlock(
hid_chans,
hid_chans,
num_freqs,
ksz=ksz_dense,
activation=activation,
hid_chans=hid_chans_dense,
),
)
freq_axis_dims = self._get_depth(num_freqs)
self.encoder = torch.nn.ModuleList([])
self.encoder.append(first)
for layer_indx in range(len(freq_axis_dims)):
downsample = Conv2DActNorm(
hid_chans, hid_chans, (3, 3), (1, 2), (1, 0), activation=activation
)
denseblocks = DenseBlock(
hid_chans,
hid_chans,
freq_axis_dims[layer_indx],
ksz=ksz_dense,
activation=activation,
hid_chans=hid_chans_dense,
)
c_layer = torch.nn.Sequential(downsample, denseblocks)
self.encoder.append(c_layer)
self.encoder.append(
Conv2DActNorm(
hid_chans, hid_chans * 2, (3, 3), (1, 2), (1, 0), activation=activation
)
)
self.encoder.append(
Conv2DActNorm(
hid_chans * 2,
hid_chans * 4,
(3, 3),
(1, 2),
(1, 0),
activation=activation,
)
)
self.encoder.append(
Conv2DActNorm(
hid_chans * 4,
tcn_channels,
(3, 3),
(1, 1),
(1, 0),
activation=activation,
)
)
self.tcn = []
for r in range(tcn_repeats):
for x in range(tcn_blocks):
self.tcn.append(
TCNResBlock(
tcn_channels,
tcn_channels,
ksz_tcn,
dilation=2**x,
activation=activation,
)
)
self.tcn = torch.nn.Sequential(*self.tcn)
self.decoder = torch.nn.ModuleList([])
self.decoder.append(
Conv2DActNorm(
tcn_channels * 2,
hid_chans * 4,
(3, 3),
(1, 1),
(1, 0),
activation=activation,
upsample=True,
)
)
self.decoder.append(
Conv2DActNorm(
hid_chans * 8,
hid_chans * 2,
(3, 3),
(1, 2),
(1, 0),
activation=activation,
upsample=True,
)
)
self.decoder.append(
Conv2DActNorm(
hid_chans * 4,
hid_chans,
(3, 3),
(1, 2),
(1, 0),
activation=activation,
upsample=True,
)
)
for dec_indx in range(len(freq_axis_dims)):
c_num_freqs = freq_axis_dims[len(freq_axis_dims) - dec_indx - 1]
denseblocks = DenseBlock(
hid_chans * 2,
hid_chans * 2,
c_num_freqs,
ksz=ksz_dense,
activation=activation,
hid_chans=hid_chans_dense,
)
upsample = Conv2DActNorm(
hid_chans * 2,
hid_chans,
(3, 3),
(1, 2),
(1, 0),
activation=activation,
upsample=True,
)
c_layer = torch.nn.Sequential(denseblocks, upsample)
self.decoder.append(c_layer)
last = torch.nn.Sequential(
DenseBlock(
hid_chans * 2,
hid_chans * 2,
self.in_channels - 2,
ksz=ksz_dense,
activation=activation,
hid_chans=hid_chans_dense,
),
torch.nn.ConvTranspose2d(
hid_chans * 2, 2 * self.n_spk, (3, 3), (1, 1), (1, 0)
),
)
self.decoder.append(last)
def _get_depth(self, num_freq):
n_layers = 0
freqs = []
while num_freq > 15:
num_freq = int(num_freq / 2)
freqs.append(num_freq)
n_layers += 1
return freqs
[docs] def forward(self, tf_rep):
"""forward.
Args:
tf_rep (torch.Tensor): 4D tensor (multi-channel complex STFT of mixture)
of shape [B, T, C, F] batch, frames, microphones, frequencies.
Returns:
out (torch.Tensor): complex 3D tensor monaural STFT of the targets
shape is [B, T, F] batch, frames, frequencies.
"""
# B, T, C, F
tf_rep = tf_rep.permute(0, 2, 3, 1)
bsz, mics, _, frames = tf_rep.shape
assert mics == self.mic_channels
inp_feats = torch.cat((tf_rep.real, tf_rep.imag), 1)
inp_feats = inp_feats.transpose(-1, -2)
inp_feats = inp_feats.reshape(
bsz, self.mic_channels * 2, frames, self.in_channels
)
enc_out = []
buffer = inp_feats
for enc_layer in self.encoder:
buffer = enc_layer(buffer)
enc_out.append(buffer)
assert buffer.shape[-1] == 1
tcn_out = self.tcn(buffer.squeeze(-1)).unsqueeze(-1)
buffer = tcn_out
for indx, dec_layer in enumerate(self.decoder):
c_input = torch.cat((buffer, enc_out[-(indx + 1)]), 1)
buffer = dec_layer(c_input)
buffer = buffer.reshape(bsz, 2, self.n_spk, -1, self.in_channels)
if is_torch_1_9_plus:
out = torch.complex(buffer[:, 0], buffer[:, 1])
else:
out = ComplexTensor(buffer[:, 0], buffer[:, 1])
# bsz, complex_chans, frames or bsz, spk, complex_chans, frames
return out # bsz, spk, time, freq -> bsz, time, spk, freq