from typing import Optional

import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.layers.stft import Stft

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class STFTEncoder(AbsEncoder):
    """STFT encoder for speech enhancement and separation"""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        use_builtin_complex: bool = True,
    ):
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )
        self._output_dim = n_fft // 2 + 1 if onesided else n_fft
        self.use_builtin_complex = use_builtin_complex
        self.win_length = win_length if win_length else n_fft
        self.hop_length = hop_length
        self.window = window
        self.n_fft = n_fft
        self.center = center

    @property
    def output_dim(self) -> int:
        return self._output_dim

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            spectrum (torch.Tensor or ComplexTensor): [Batch, T, F]
            flens (torch.Tensor): output lengths [Batch]
        """
        # for half-precision training: compute the STFT in float32,
        # then cast the result back to the input dtype
        if input.dtype in (torch.float16, torch.bfloat16):
            spectrum, flens = self.stft(input.float(), ilens)
            spectrum = spectrum.to(dtype=input.dtype)
        else:
            spectrum, flens = self.stft(input, ilens)

        # combine the stacked real and imaginary parts into a complex tensor
        if is_torch_1_9_plus and self.use_builtin_complex:
            spectrum = torch.complex(spectrum[..., 0], spectrum[..., 1])
        else:
            spectrum = ComplexTensor(spectrum[..., 0], spectrum[..., 1])
        return spectrum, flens
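
    # Usage sketch (illustrative values, an addition for clarity; not part of
    # the original module):
    #   enc = STFTEncoder(n_fft=512, hop_length=128)
    #   spec, flens = enc(torch.randn(2, 16000), torch.tensor([16000, 16000]))
    #   spec.shape  # (2, 126, 257): 257 = 512 // 2 + 1 onesided bins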

    def _apply_window_func(self, input):
        B = input.shape[0]

        window_func = getattr(torch, f"{self.window}_window")
        window = window_func(self.win_length, dtype=input.dtype, device=input.device)
        n_pad_left = (self.n_fft - window.shape[0]) // 2
        n_pad_right = self.n_fft - window.shape[0] - n_pad_left

        windowed = input * window

        # zero-pad the windowed frame to n_fft samples, centering the window;
        # the zeros must match the input's dtype/device for torch.cat to work
        windowed = torch.cat(
            [
                torch.zeros(B, n_pad_left, dtype=input.dtype, device=input.device),
                windowed,
                torch.zeros(B, n_pad_right, dtype=input.dtype, device=input.device),
            ],
            1,
        )
        return windowed
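
    # Padding arithmetic above, on assumed numbers: with n_fft=512 and
    # win_length=400, the 400-point window is centered in the 512-point frame
    # with n_pad_left = 56 and n_pad_right = 56 zeros, mirroring how torch.stft
    # pads a window shorter than n_fft.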

    def forward_streaming(self, input: torch.Tensor):
        """Forward streaming.

        Args:
            input (torch.Tensor): mixed speech [Batch, frame_length]

        Returns:
            feature (torch.Tensor or ComplexTensor): [Batch, 1, F]
        """
        assert (
            input.dim() == 2
        ), "forward_streaming currently only supports single-channel input."
        windowed = self._apply_window_func(input)

        # rfft yields the onesided spectrum; fft yields the full spectrum
        feature = (
            torch.fft.rfft(windowed) if self.stft.onesided else torch.fft.fft(windowed)
        )
        feature = feature.unsqueeze(1)
        if not (is_torch_1_9_plus and self.use_builtin_complex):
            feature = ComplexTensor(feature.real, feature.imag)

        return feature
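
    # Streaming sketch (illustrative): each call consumes one raw frame of
    # win_length samples; the analysis window is applied internally.
    #   frame = torch.randn(1, 512)          # (Batch, frame_length)
    #   spec = enc.forward_streaming(frame)  # (1, 1, 257)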

    def streaming_frame(self, audio):
        """streaming_frame. It splits the continuous audio into frame-level
        audio chunks in the streaming *simulation*. Note that this function
        takes the entire long audio as input for a streaming simulation.
        You may refer to this function to manage your streaming input
        buffer in a real streaming application.

        Args:
            audio: (B, T)
        Returns:
            chunked: List [(B, frame_size),]
        """
        if self.center:
            pad_len = int(self.win_length // 2)
            signal_dim = audio.dim()
            extended_shape = [1] * (3 - signal_dim) + list(audio.size())
            # the default STFT pad mode is "reflect", which is not configurable
            # in the STFT encoder, so we simply use "reflect" mode here
            audio = torch.nn.functional.pad(
                audio.view(extended_shape), [pad_len, pad_len], "reflect"
            )
            audio = audio.view(audio.shape[-signal_dim:])

        _, audio_len = audio.shape
        n_frames = 1 + (audio_len - self.win_length) // self.hop_length

        # build overlapping frames as a strided view of shape
        # (B, win_length, n_frames), stepping hop_length samples between frames
        strides = list(audio.stride())
        shape = list(audio.shape[:-1]) + [self.win_length, n_frames]
        strides = strides + [self.hop_length]
        return audio.as_strided(shape, strides, storage_offset=0).unbind(dim=-1)
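

if __name__ == "__main__":
    # Minimal self-check sketch (an addition for illustration, not part of the
    # original module); assumes torch >= 1.9 so chunks are builtin complex
    # tensors. The frame-by-frame streaming path should reproduce the offline
    # STFT up to numerical tolerance.
    torch.manual_seed(0)
    encoder = STFTEncoder(n_fft=512, win_length=512, hop_length=128)
    audio = torch.randn(1, 16000)  # (Batch, sample)
    ilens = torch.tensor([16000])

    # offline path: (Batch, frames, n_fft // 2 + 1) complex spectrum
    spectrum, flens = encoder(audio, ilens)

    # streaming simulation: split the waveform into frames, encode one by one
    frames = encoder.streaming_frame(audio)
    chunks = [encoder.forward_streaming(f) for f in frames]  # each (Batch, 1, F)
    streamed = torch.cat(chunks, dim=1)

    print(spectrum.shape, streamed.shape, flens)
    print(torch.allclose(spectrum, streamed, atol=1e-4))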