# An implementation of SkiM model described in
# "SkiM: Skipping Memory LSTM for Low-Latency Real-Time Continuous Speech Separation"
# (https://arxiv.org/abs/2201.10800)
#
import torch
import torch.nn as nn
from espnet2.enh.layers.dprnn import SingleRNN, merge_feature, split_feature
from espnet2.enh.layers.tcn import choose_norm
class MemLSTM(nn.Module):
"""the Mem-LSTM of SkiM
args:
hidden_size: int, dimension of the hidden state.
dropout: float, dropout ratio. Default is 0.
bidirectional: bool, whether the LSTM layers are bidirectional.
Default is False.
mem_type: 'hc', 'h', 'c' or 'id'.
It controls whether the hidden (or cell) state of
SegLSTM will be processed by MemLSTM.
In 'id' mode, both the hidden and cell states will
be identically returned.
norm_type: gLN, cLN. cLN is for causal implementation.
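
    Example:
        A minimal shape sketch with arbitrary values, following the
        ``(d, B*S, H)`` state layout documented in ``forward`` (here B=2
        batches, S=5 segments, H=hidden_size=16, d=1 direction):

        >>> import torch
        >>> mem = MemLSTM(hidden_size=16, bidirectional=False, mem_type="hc")
        >>> h = torch.randn(1, 2 * 5, 16)
        >>> c = torch.randn(1, 2 * 5, 16)
        >>> h_out, c_out = mem((h, c), S=5)
        >>> h_out.shape, c_out.shape
        (torch.Size([1, 10, 16]), torch.Size([1, 10, 16]))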
"""
def __init__(
self,
hidden_size,
dropout=0.0,
bidirectional=False,
mem_type="hc",
norm_type="cLN",
):
super().__init__()
self.hidden_size = hidden_size
self.bidirectional = bidirectional
self.input_size = (int(bidirectional) + 1) * hidden_size
self.mem_type = mem_type
assert mem_type in [
"hc",
"h",
"c",
"id",
], f"only support 'hc', 'h', 'c' and 'id', current type: {mem_type}"
if mem_type in ["hc", "h"]:
self.h_net = SingleRNN(
"LSTM",
input_size=self.input_size,
hidden_size=self.hidden_size,
dropout=dropout,
bidirectional=bidirectional,
)
self.h_norm = choose_norm(
norm_type=norm_type, channel_size=self.input_size, shape="BTD"
)
if mem_type in ["hc", "c"]:
self.c_net = SingleRNN(
"LSTM",
input_size=self.input_size,
hidden_size=self.hidden_size,
dropout=dropout,
bidirectional=bidirectional,
)
self.c_norm = choose_norm(
norm_type=norm_type, channel_size=self.input_size, shape="BTD"
)
    def forward(self, hc, S):
# hc = (h, c), tuple of hidden and cell states from SegLSTM
# shape of h and c: (d, B*S, H)
# S: number of segments in SegLSTM
if self.mem_type == "id":
ret_val = hc
h, c = hc
d, BS, H = h.shape
B = BS // S
else:
h, c = hc
d, BS, H = h.shape
B = BS // S
h = h.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH
c = c.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH
if self.mem_type == "hc":
h = h + self.h_norm(self.h_net(h)[0])
c = c + self.c_norm(self.c_net(c)[0])
elif self.mem_type == "h":
h = h + self.h_norm(self.h_net(h)[0])
c = torch.zeros_like(c)
elif self.mem_type == "c":
h = torch.zeros_like(h)
c = c + self.c_norm(self.c_net(c)[0])
h = h.view(B * S, d, H).transpose(1, 0).contiguous()
c = c.view(B * S, d, H).transpose(1, 0).contiguous()
ret_val = (h, c)
if not self.bidirectional:
# for causal setup
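            # Shift the memory outputs one segment to the right along S so that
            # segment s only receives states derived from segments 0..s-1; the
            # first segment starts from all-zero states.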
causal_ret_val = []
for x in ret_val:
x = x.transpose(1, 0).contiguous().view(B, S, d * H)
x_ = torch.zeros_like(x)
x_[:, 1:, :] = x[:, :-1, :]
x_ = x_.view(B * S, d, H).transpose(1, 0).contiguous()
causal_ret_val.append(x_)
ret_val = tuple(causal_ret_val)
return ret_val
    def forward_one_step(self, hc, state):
if self.mem_type == "id":
pass
else:
h, c = hc
d, B, H = h.shape
h = h.transpose(1, 0).contiguous().view(B, 1, d * H) # B, 1, dH
c = c.transpose(1, 0).contiguous().view(B, 1, d * H) # B, 1, dH
if self.mem_type == "hc":
h_tmp, state[0] = self.h_net(h, state[0])
h = h + self.h_norm(h_tmp)
c_tmp, state[1] = self.c_net(c, state[1])
c = c + self.c_norm(c_tmp)
elif self.mem_type == "h":
h_tmp, state[0] = self.h_net(h, state[0])
h = h + self.h_norm(h_tmp)
c = torch.zeros_like(c)
elif self.mem_type == "c":
h = torch.zeros_like(h)
c_tmp, state[1] = self.c_net(c, state[1])
c = c + self.c_norm(c_tmp)
h = h.transpose(1, 0).contiguous()
c = c.transpose(1, 0).contiguous()
hc = (h, c)
return hc, state
class SegLSTM(nn.Module):
"""the Seg-LSTM of SkiM
args:
input_size: int, dimension of the input feature.
The input should have shape (batch, seq_len, input_size).
hidden_size: int, dimension of the hidden state.
dropout: float, dropout ratio. Default is 0.
bidirectional: bool, whether the LSTM layers are bidirectional.
Default is False.
norm_type: gLN, cLN. cLN is for causal implementation.
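
    Example:
        A minimal shape sketch with arbitrary values, processing one segment
        of length 20 without previous states (hc=None):

        >>> import torch
        >>> seg = SegLSTM(input_size=16, hidden_size=11)
        >>> x = torch.randn(2, 20, 16)  # (batch, segment_length, input_size)
        >>> y, (h, c) = seg(x, None)
        >>> y.shape, h.shape, c.shape
        (torch.Size([2, 20, 16]), torch.Size([1, 2, 11]), torch.Size([1, 2, 11]))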
"""
def __init__(
self, input_size, hidden_size, dropout=0.0, bidirectional=False, norm_type="cLN"
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_direction = int(bidirectional) + 1
self.lstm = nn.LSTM(
input_size,
hidden_size,
1,
batch_first=True,
bidirectional=bidirectional,
)
self.dropout = nn.Dropout(p=dropout)
self.proj = nn.Linear(hidden_size * self.num_direction, input_size)
self.norm = choose_norm(
norm_type=norm_type, channel_size=input_size, shape="BTD"
)
    def forward(self, input, hc):
# input shape: B, T, H
B, T, H = input.shape
if hc is None:
            # In the first SkiM block, h and c are not available
d = self.num_direction
h = torch.zeros(d, B, self.hidden_size, dtype=input.dtype).to(input.device)
c = torch.zeros(d, B, self.hidden_size, dtype=input.dtype).to(input.device)
else:
h, c = hc
output, (h, c) = self.lstm(input, (h, c))
output = self.dropout(output)
output = self.proj(output.contiguous().view(-1, output.shape[2])).view(
input.shape
)
output = input + self.norm(output)
return output, (h, c)
class SkiM(nn.Module):
"""Skipping Memory Net
args:
input_size: int, dimension of the input feature.
            Input shape should be (batch, length, input_size)
hidden_size: int, dimension of the hidden state.
output_size: int, dimension of the output size.
dropout: float, dropout ratio. Default is 0.
num_blocks: number of basic SkiM blocks
segment_size: segmentation size for splitting long features
bidirectional: bool, whether the RNN layers are bidirectional.
mem_type: 'hc', 'h', 'c', 'id' or None.
It controls whether the hidden (or cell) state of SegLSTM
will be processed by MemLSTM.
In 'id' mode, both the hidden and cell states will
be identically returned.
When mem_type is None, the MemLSTM will be removed.
norm_type: gLN, cLN. cLN is for causal implementation.
        seg_overlap: bool, whether the segmentation will reserve 50%
            overlap for adjacent segments. Default is False.
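
    Example:
        A minimal end-to-end sketch with arbitrary values; a non-causal
        configuration maps (batch, length, input_size) features to
        (batch, length, output_size):

        >>> import torch
        >>> model = SkiM(input_size=16, hidden_size=11, output_size=8,
        ...              num_blocks=2, segment_size=20, bidirectional=True,
        ...              mem_type="hc", norm_type="gLN", seg_overlap=True)
        >>> x = torch.randn(2, 100, 16)
        >>> model(x).shape
        torch.Size([2, 100, 8])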
"""
def __init__(
self,
input_size,
hidden_size,
output_size,
dropout=0.0,
num_blocks=2,
segment_size=20,
bidirectional=True,
mem_type="hc",
norm_type="gLN",
seg_overlap=False,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.segment_size = segment_size
self.dropout = dropout
self.num_blocks = num_blocks
self.mem_type = mem_type
self.norm_type = norm_type
self.seg_overlap = seg_overlap
assert mem_type in [
"hc",
"h",
"c",
"id",
None,
], f"only support 'hc', 'h', 'c', 'id', and None, current type: {mem_type}"
self.seg_lstms = nn.ModuleList([])
for i in range(num_blocks):
self.seg_lstms.append(
SegLSTM(
input_size=input_size,
hidden_size=hidden_size,
dropout=dropout,
bidirectional=bidirectional,
norm_type=norm_type,
)
)
if self.mem_type is not None:
self.mem_lstms = nn.ModuleList([])
for i in range(num_blocks - 1):
self.mem_lstms.append(
MemLSTM(
hidden_size,
dropout=dropout,
bidirectional=bidirectional,
mem_type=mem_type,
norm_type=norm_type,
)
)
self.output_fc = nn.Sequential(
nn.PReLU(), nn.Conv1d(input_size, output_size, 1)
)
    def forward(self, input):
# input shape: B, T (S*K), D
B, T, D = input.shape
if self.seg_overlap:
input, rest = split_feature(
input.transpose(1, 2), segment_size=self.segment_size
) # B, D, K, S
input = input.permute(0, 3, 2, 1).contiguous() # B, S, K, D
else:
input, rest = self._padfeature(input=input)
input = input.view(B, -1, self.segment_size, D) # B, S, K, D
B, S, K, D = input.shape
assert K == self.segment_size
output = input.view(B * S, K, D).contiguous() # BS, K, D
hc = None
for i in range(self.num_blocks):
output, hc = self.seg_lstms[i](output, hc) # BS, K, D
if self.mem_type and i < self.num_blocks - 1:
hc = self.mem_lstms[i](hc, S)
if self.seg_overlap:
output = output.view(B, S, K, D).permute(0, 3, 2, 1) # B, D, K, S
output = merge_feature(output, rest) # B, D, T
output = self.output_fc(output).transpose(1, 2)
else:
output = output.view(B, S * K, D)[:, :T, :] # B, T, D
output = self.output_fc(output.transpose(1, 2)).transpose(1, 2)
return output
def _padfeature(self, input):
B, T, D = input.shape
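        # Zero-pad the time axis up to a multiple of segment_size, e.g. with
        # segment_size=20 and T=45, rest=15 and the padded length is 60
        # (3 segments). When T is already a multiple of segment_size, one full
        # extra segment of zeros is appended; forward() trims the output back
        # to the original length T.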
rest = self.segment_size - T % self.segment_size
if rest > 0:
input = torch.nn.functional.pad(input, (0, 0, 0, rest))
return input, rest
    def forward_stream(self, input_frame, states):
        # input_frame shape: B, 1, N
B, _, N = input_frame.shape
def empty_seg_states():
shp = (1, B, self.hidden_size)
return (
torch.zeros(*shp, device=input_frame.device, dtype=input_frame.dtype),
torch.zeros(*shp, device=input_frame.device, dtype=input_frame.dtype),
)
if not states:
states = {
"current_step": 0,
"seg_state": [empty_seg_states() for i in range(self.num_blocks)],
"mem_state": [[None, None] for i in range(self.num_blocks - 1)],
}
output = input_frame
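        # At every segment boundary, refresh the SegLSTM states: block i+1 of
        # the new segment starts from the MemLSTM-processed states of block i
        # from the segment that just ended, while block 0 restarts from zeros.
        # This mirrors the causal shift-by-one-segment behaviour of
        # MemLSTM.forward.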
if states["current_step"] and (states["current_step"]) % self.segment_size == 0:
tmp_states = [empty_seg_states() for i in range(self.num_blocks)]
for i in range(self.num_blocks - 1):
tmp_states[i + 1], states["mem_state"][i] = self.mem_lstms[
i
].forward_one_step(states["seg_state"][i], states["mem_state"][i])
states["seg_state"] = tmp_states
for i in range(self.num_blocks):
output, states["seg_state"][i] = self.seg_lstms[i](
output, states["seg_state"][i]
)
states["current_step"] += 1
output = self.output_fc(output.transpose(1, 2)).transpose(1, 2)
return output, states
if __name__ == "__main__":
torch.manual_seed(111)
seq_len = 100
model = SkiM(
16,
11,
16,
dropout=0.0,
num_blocks=4,
segment_size=20,
bidirectional=False,
mem_type="hc",
norm_type="cLN",
seg_overlap=False,
)
model.eval()
input = torch.randn(3, seq_len, 16)
seg_output = model(input)
state = None
for i in range(seq_len):
input_frame = input[:, i : i + 1, :]
output, state = model.forward_stream(input_frame=input_frame, states=state)
        torch.testing.assert_close(output, seg_output[:, i : i + 1, :])
print("streaming ok")