Source code for espnet2.enh.encoder.stft_encoder

import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.layers.stft import Stft

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class STFTEncoder(AbsEncoder):
    """STFT encoder for speech enhancement and separation"""

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        use_builtin_complex: bool = True,
    ):
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )
        self._output_dim = n_fft // 2 + 1 if onesided else n_fft
        self.use_builtin_complex = use_builtin_complex
        self.win_length = win_length if win_length else n_fft
        self.hop_length = hop_length
        self.window = window
        self.n_fft = n_fft
        self.center = center

    @property
    def output_dim(self) -> int:
        return self._output_dim
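    # Illustrative note (not part of the original source): with the defaults
    # (n_fft=512, onesided=True), output_dim is 512 // 2 + 1 = 257 frequency
    # bins; with onesided=False it would be the full 512 bins. win_length
    # falls back to n_fft when it is left as None.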
    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]
        """
        # for supporting half-precision training
        if input.dtype in (torch.float16, torch.bfloat16):
            spectrum, flens = self.stft(input.float(), ilens)
            spectrum = spectrum.to(dtype=input.dtype)
        else:
            spectrum, flens = self.stft(input, ilens)

        if is_torch_1_9_plus and self.use_builtin_complex:
            spectrum = torch.complex(spectrum[..., 0], spectrum[..., 1])
        else:
            spectrum = ComplexTensor(spectrum[..., 0], spectrum[..., 1])

        return spectrum, flens
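    # Usage sketch (illustrative, not part of the original source): for a
    # float32 batch `mix` of shape [8, 16000] with the default settings
    # (n_fft=512, hop_length=128, center=True, onesided=True), something like
    #
    #     encoder = STFTEncoder()
    #     spec, flens = encoder(mix, ilens)  # ilens: LongTensor of sample lengths
    #
    # is expected to yield a complex `spec` of roughly shape [8, 126, 257]
    # (frames x frequency bins) and `flens` with the valid frame count per
    # utterance; the exact frame count follows torch.stft's centering rule.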
    def _apply_window_func(self, input):
        B = input.shape[0]

        window_func = getattr(torch, f"{self.window}_window")
        window = window_func(self.win_length, dtype=input.dtype, device=input.device)
        n_pad_left = (self.n_fft - window.shape[0]) // 2
        n_pad_right = self.n_fft - window.shape[0] - n_pad_left

        windowed = input * window

        # zero-pad the windowed frame up to n_fft samples;
        # keep the padding on the same device/dtype as the input
        windowed = torch.cat(
            [
                torch.zeros(B, n_pad_left, dtype=input.dtype, device=input.device),
                windowed,
                torch.zeros(B, n_pad_right, dtype=input.dtype, device=input.device),
            ],
            1,
        )
        return windowed
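    # Worked example (illustrative, not part of the original source): with
    # win_length=400 and n_fft=512, the window covers 400 samples and the
    # frame is zero-padded by n_pad_left = (512 - 400) // 2 = 56 samples on
    # the left and n_pad_right = 512 - 400 - 56 = 56 on the right, so a
    # (B, 400) input frame becomes a (B, 512) windowed frame ready for the FFT.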
    def forward_streaming(self, input: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, frame_length]
        Return:
            B, 1, F
        """
        assert (
            input.dim() == 2
        ), "forward_streaming only supports single-channel input currently."

        windowed = self._apply_window_func(input)

        feature = (
            torch.fft.rfft(windowed) if self.stft.onesided else torch.fft.fft(windowed)
        )
        feature = feature.unsqueeze(1)
        if not (is_torch_1_9_plus and self.use_builtin_complex):
            feature = ComplexTensor(feature.real, feature.imag)

        return feature
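    # Shape sketch (illustrative, not part of the original source): feeding a
    # single frame of shape (B, win_length) with the defaults returns one
    # complex spectrum of shape (B, 1, 257) when onesided=True, or (B, 1, n_fft)
    # otherwise, corresponding to one time step of forward().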
    def streaming_frame(self, audio):
        """streaming_frame.

        It splits the continuous audio into frame-level audio chunks in the
        streaming *simulation*. It is noted that this function takes the
        entire long audio as input for a streaming simulation. You may refer
        to this function to manage your streaming input buffer in a real
        streaming application.

        Args:
            audio: (B, T)
        Returns:
            chunked: List [(B, frame_size),]
        """
        if self.center:
            pad_len = int(self.win_length // 2)
            signal_dim = audio.dim()
            extended_shape = [1] * (3 - signal_dim) + list(audio.size())
            # the default STFT pad mode is "reflect",
            # which is not configurable in STFT encoder,
            # so, here we just use "reflect" mode
            audio = torch.nn.functional.pad(
                audio.view(extended_shape), [pad_len, pad_len], "reflect"
            )
            audio = audio.view(audio.shape[-signal_dim:])

        _, audio_len = audio.shape

        n_frames = 1 + (audio_len - self.win_length) // self.hop_length
        strides = list(audio.stride())

        shape = list(audio.shape[:-1]) + [self.win_length, n_frames]
        strides = strides + [self.hop_length]

        return audio.as_strided(shape, strides, storage_offset=0).unbind(dim=-1)
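A minimal end-to-end sketch, not part of the original module: it assumes a single-channel waveform of shape [1, 16000] and compares the offline spectrum from forward() with the frame-by-frame spectra produced by streaming_frame() followed by forward_streaming(). Tensor names and the printed tolerance check are illustrative only.

# Usage sketch (assumes the STFTEncoder defined above and the module's imports)
encoder = STFTEncoder(n_fft=512, hop_length=128, use_builtin_complex=True)

wav = torch.randn(1, 16000)
ilens = torch.LongTensor([16000])

# Offline encoding: complex spectrum of shape [Batch, Frames, Freq]
spec, flens = encoder(wav, ilens)

# Streaming simulation: split the waveform into overlapping frames,
# then encode each frame independently and stack along the frame axis.
frames = encoder.streaming_frame(wav)
stream_spec = torch.cat(
    [encoder.forward_streaming(frame) for frame in frames], dim=1
)

# The streaming spectra are expected to be close to the offline ones
# over the frames that both paths cover.
n = min(spec.shape[1], stream_spec.shape[1])
diff = (spec[:, :n] - stream_spec[:, :n]).abs().max()
print(f"max |offline - streaming| = {diff.item():.3e}")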