Source code for espnet2.enh.layers.dptnet
# The implementation of DPTNet proposed in
# J. Chen, Q. Mao, and D. Liu, “Dual-path transformer network:
# Direct context-aware modeling for end-to-end monaural speech
# separation,” in Proc. ISCA Interspeech, 2020, pp. 2642–2646.
#
# Ported from https://github.com/ujscjj/DPTNet
import torch.nn as nn
from espnet2.enh.layers.tcn import choose_norm
from espnet.nets.pytorch_backend.nets_utils import get_activation


class ImprovedTransformerLayer(nn.Module):
"""Container module of the (improved) Transformer proposed in [1].
Reference:
Dual-path transformer network: Direct context-aware modeling for end-to-end
monaural speech separation; Chen et al, Interspeech 2020.
Args:
rnn_type (str): select from 'RNN', 'LSTM' and 'GRU'.
input_size (int): Dimension of the input feature.
att_heads (int): Number of attention heads.
hidden_size (int): Dimension of the hidden state.
dropout (float): Dropout ratio. Default is 0.
activation (str): activation function applied at the output of RNN.
bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
(Intra-Chunk is always bidirectional).
norm (str, optional): Type of normalization to use.
"""
def __init__(
self,
rnn_type,
input_size,
att_heads,
hidden_size,
dropout=0.0,
activation="relu",
bidirectional=True,
norm="gLN",
):
super().__init__()
rnn_type = rnn_type.upper()
assert rnn_type in [
"RNN",
"LSTM",
"GRU",
], f"Only support 'RNN', 'LSTM' and 'GRU', current type: {rnn_type}"
self.rnn_type = rnn_type
self.att_heads = att_heads
self.self_attn = nn.MultiheadAttention(input_size, att_heads, dropout=dropout)
self.dropout = nn.Dropout(p=dropout)
self.norm_attn = choose_norm(norm, input_size)
self.rnn = getattr(nn, rnn_type)(
input_size,
hidden_size,
1,
batch_first=True,
bidirectional=bidirectional,
)
activation = get_activation(activation)
hdim = 2 * hidden_size if bidirectional else hidden_size
self.feed_forward = nn.Sequential(
activation, nn.Dropout(p=dropout), nn.Linear(hdim, input_size)
)
self.norm_ff = choose_norm(norm, input_size)

    def forward(self, x, attn_mask=None):
# (batch, seq, input_size) -> (seq, batch, input_size)
src = x.permute(1, 0, 2)
# (seq, batch, input_size) -> (batch, seq, input_size)
out = self.self_attn(src, src, src, attn_mask=attn_mask)[0].permute(1, 0, 2)
out = self.dropout(out) + x
# ... -> (batch, input_size, seq) -> ...
out = self.norm_attn(out.transpose(-1, -2)).transpose(-1, -2)
out2 = self.feed_forward(self.rnn(out)[0])
out2 = self.dropout(out2) + out
return self.norm_ff(out2.transpose(-1, -2)).transpose(-1, -2)
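
# An illustrative usage sketch for ImprovedTransformerLayer (not part of the
# original module; all sizes are arbitrary). The layer maps a
# (batch, seq, input_size) tensor to a tensor of the same shape:
#
#     layer = ImprovedTransformerLayer("LSTM", 64, att_heads=4, hidden_size=128)
#     x = torch.randn(2, 100, 64)  # requires `import torch`
#     y = layer(x)                 # -> (2, 100, 64)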


class DPTNet(nn.Module):
"""Dual-path transformer network.
    Args:
rnn_type (str): select from 'RNN', 'LSTM' and 'GRU'.
input_size (int): dimension of the input feature.
Input size must be a multiple of `att_heads`.
hidden_size (int): dimension of the hidden state.
output_size (int): dimension of the output size.
att_heads (int): number of attention heads.
dropout (float): dropout ratio. Default is 0.
activation (str): activation function applied at the output of RNN.
        num_layers (int): number of stacked dual-path transformer blocks. Default is 1.
        bidirectional (bool): whether the inter-chunk RNN is bidirectional
            (the intra-chunk RNN is always bidirectional). Default is True.
        norm_type (str): type of normalization to use after each inter- or
            intra-chunk Transformer block.
"""
def __init__(
self,
rnn_type,
input_size,
hidden_size,
output_size,
att_heads=4,
dropout=0,
activation="relu",
num_layers=1,
bidirectional=True,
norm_type="gLN",
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
# dual-path transformer
self.row_transformer = nn.ModuleList()
self.col_transformer = nn.ModuleList()
for i in range(num_layers):
self.row_transformer.append(
ImprovedTransformerLayer(
rnn_type,
input_size,
att_heads,
hidden_size,
dropout=dropout,
activation=activation,
bidirectional=True,
norm=norm_type,
)
) # intra-segment RNN is always noncausal
self.col_transformer.append(
ImprovedTransformerLayer(
rnn_type,
input_size,
att_heads,
hidden_size,
dropout=dropout,
activation=activation,
bidirectional=bidirectional,
norm=norm_type,
)
)
# output layer
self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))

    def forward(self, input):
# input shape: batch, N, dim1, dim2
# apply Transformer on dim1 first and then dim2
# output shape: B, output_size, dim1, dim2
# input = input.to(device)
output = input
for i in range(len(self.row_transformer)):
output = self.intra_chunk_process(output, i)
output = self.inter_chunk_process(output, i)
output = self.output(output) # B, output_size, dim1, dim2
return output

    def intra_chunk_process(self, x, layer_index):
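        # View each chunk as an independent sequence so the intra-chunk (row)
        # transformer models positions within a chunk:
        # (batch, N, chunk_size, n_chunks) -> (batch * n_chunks, chunk_size, N)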
batch, N, chunk_size, n_chunks = x.size()
x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, N)
x = self.row_transformer[layer_index](x)
x = x.reshape(batch, n_chunks, chunk_size, N).permute(0, 3, 2, 1)
return x

    def inter_chunk_process(self, x, layer_index):
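        # Gather the same position across chunks so the inter-chunk (column)
        # transformer models dependencies between chunks:
        # (batch, N, chunk_size, n_chunks) -> (batch * chunk_size, n_chunks, N)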
batch, N, chunk_size, n_chunks = x.size()
x = x.permute(0, 2, 3, 1).reshape(batch * chunk_size, n_chunks, N)
x = self.col_transformer[layer_index](x)
x = x.reshape(batch, chunk_size, n_chunks, N).permute(0, 3, 1, 2)
return x
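

# A minimal, illustrative demo (not part of the original module): build a small
# DPTNet and run a random dual-path tensor through it. All sizes are arbitrary;
# the only constraint is that `input_size` (= N) is a multiple of `att_heads`.
if __name__ == "__main__":
    import torch

    model = DPTNet(
        rnn_type="LSTM",
        input_size=64,
        hidden_size=128,
        output_size=128,
        att_heads=4,
        dropout=0.1,
        num_layers=2,
    )
    dummy = torch.randn(2, 64, 50, 30)  # (batch, N, chunk_size, n_chunks)
    out = model(dummy)
    print(out.shape)  # torch.Size([2, 128, 50, 30])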