# Source code for espnet2.enh.decoder.stft_decoder

import math

import torch
import torch_complex
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.decoder.abs_decoder import AbsDecoder
from espnet2.enh.layers.complex_utils import is_torch_complex_tensor
from espnet2.layers.stft import Stft

# True when the installed torch is version 1.9.0 or newer. The code below only
# consults torch.is_complex() for the input type check when this flag is set.
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class STFTDecoder(AbsDecoder):
    """STFT decoder for speech enhancement and separation.

    Wraps an ``Stft`` layer and uses its ``inverse`` method to turn complex
    spectra back into waveforms. It also provides frame-level helpers
    (``forward_streaming`` / ``streaming_merge``) for streaming simulation.
    """

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        """Build the decoder.

        Args:
            n_fft: FFT size.
            win_length: window length; falls back to ``n_fft`` when falsy.
            hop_length: hop size between consecutive frames.
            window: window name; ``torch.<window>_window`` is looked up by
                the streaming helpers.
            center: forwarded to ``Stft``; also decides whether
                ``streaming_merge`` trims half a frame of padding.
            normalized: forwarded to ``Stft``.
            onesided: forwarded to ``Stft``.
        """
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

        self.win_length = win_length if win_length else n_fft
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.window = window
        self.center = center

    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            tuple: ``(wav, wav_lens)`` as returned by ``self.stft.inverse``;
            for 4-dim input the waveform is rearranged to
            (Batch, Nsamples, C).

        Raises:
            TypeError: if ``input`` is not a ``ComplexTensor`` and, on
                torch >= 1.9, is not a native complex tensor either.
        """
        if not isinstance(input, ComplexTensor) and (
            is_torch_1_9_plus and not torch.is_complex(input)
        ):
            raise TypeError("Only support complex tensors for stft decoder")

        bs = input.size(0)
        multi_channel = input.dim() == 4
        if multi_channel:
            # input: (Batch, T, C, F) -> (Batch * C, T, F)
            n_frames, n_freqs = input.size(1), input.size(3)
            input = input.transpose(1, 2).reshape(-1, n_frames, n_freqs)

        # for supporting half-precision training
        if input.dtype in (torch.float16, torch.bfloat16):
            wav, wav_lens = self.stft.inverse(input.float(), ilens)
            wav = wav.to(dtype=input.dtype)
        elif (
            is_torch_complex_tensor(input)
            and hasattr(torch, "complex32")
            and input.dtype == torch.complex32
        ):
            wav, wav_lens = self.stft.inverse(input.cfloat(), ilens)
            # NOTE(review): this casts the waveform back to complex32; a real
            # half-precision dtype may have been intended -- confirm.
            wav = wav.to(dtype=input.dtype)
        else:
            wav, wav_lens = self.stft.inverse(input, ilens)

        if multi_channel:
            # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
            wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

        return wav, wav_lens

    def _get_window_func(self):
        """Return the synthesis window of length ``self.win_length``.

        The window is created by ``torch.<self.window>_window`` with torch's
        default dtype and device; callers move it to the device they need.
        """
        window_func = getattr(torch, f"{self.window}_window")
        return window_func(self.win_length)

    def forward_streaming(self, input_frame: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, 1, F]
            output: wavs [Batch, 1, self.win_length]
        """
        # Rebuild the frame from its real/imag parts (presumably so that
        # ComplexTensor inputs become native complex tensors for torch.fft).
        input_frame = input_frame.real + 1j * input_frame.imag
        if self.stft.onesided:
            # Pass n explicitly: the default output length 2 * (F - 1) is one
            # sample short of n_fft when n_fft is odd.
            output_wav = torch.fft.irfft(input_frame, n=self.n_fft)
        else:
            output_wav = torch.fft.ifft(input_frame).real

        output_wav = output_wav.squeeze(1)

        # Keep only the win_length samples in the middle of the n_fft frame.
        n_pad_left = (self.n_fft - self.win_length) // 2
        output_wav = output_wav[..., n_pad_left : n_pad_left + self.win_length]

        # The window is created on the default device; move it next to the data.
        window = self._get_window_func().to(output_wav.device)
        return output_wav * window

    def streaming_merge(self, chunks, ilens=None):
        """streaming_merge.

        It merges the frame-level processed audio chunks in the streaming
        *simulation*. It is noted that, in real applications, the processed
        audio should be sent to the output channel frame by frame. You may
        refer to this function to manage your streaming output buffer.

        Args:
            chunks: List [(B, frame_size),]
            ilens: [B]; when None the output is only trimmed by the
                half-frame padding on both ends (and only if ``center``).

        Returns:
            merge_audio: [B, T]
        """
        frame_size = self.win_length
        hop_size = self.hop_length

        num_chunks = len(chunks)
        batch_size = chunks[0].shape[0]
        dtype, device = chunks[0].dtype, chunks[0].device
        audio_len = int(hop_size * num_chunks + frame_size - hop_size)

        output = torch.zeros((batch_size, audio_len), dtype=dtype, device=device)
        window_envelop = torch.zeros(
            (batch_size, audio_len), dtype=dtype, device=device
        )
        # The window is created on the default device; move it next to the data.
        window_sq = self._get_window_func().pow(2).to(device)

        # Overlap-add the frames and, in parallel, the squared window that
        # each frame was weighted with.
        for i, chunk in enumerate(chunks):
            begin = i * hop_size
            output[:, begin : begin + frame_size] += chunk
            window_envelop[:, begin : begin + frame_size] += window_sq

        # Undo the accumulated window weighting.
        output = output / window_envelop

        # We need to trim the front padding away if center.
        start = (frame_size // 2) if self.center else 0
        if ilens is None:
            # No target length given: drop the matching padding at the tail.
            end = audio_len - start
        else:
            end = start + int(ilens.max())

        return output[..., start:end]
if __name__ == "__main__":
    # Self-check: the frame-by-frame streaming path must reproduce both the
    # batch STFT/iSTFT results and the original audio.
    from espnet2.enh.encoder.stft_encoder import STFTEncoder

    input_audio = torch.randn((1, 100))
    ilens = torch.LongTensor([100])

    stft_conf = dict(n_fft=32, win_length=28, hop_length=10, onesided=True)
    encoder = STFTEncoder(**stft_conf)
    decoder = STFTDecoder(**stft_conf)

    # Batch (offline) round trip.
    frames, flens = encoder(input_audio, ilens)
    wav, ilens = decoder(frames, ilens)

    # Streaming round trip, one frame at a time.
    stream_specs = []
    stream_wavs = []
    for audio_frame in encoder.streaming_frame(input_audio):
        stream_specs.append(encoder.forward_streaming(audio_frame))
    for spec_frame in stream_specs:
        stream_wavs.append(decoder.forward_streaming(spec_frame))
    merged = decoder.streaming_merge(stream_wavs, ilens)

    # Concatenate with the function matching the complex tensor flavour in use.
    if is_torch_1_9_plus and encoder.use_builtin_complex:
        stream_specs = torch.cat(stream_specs, dim=1)
    else:
        stream_specs = torch_complex.cat(stream_specs, dim=1)

    torch.testing.assert_close(stream_specs.real, frames.real)
    torch.testing.assert_close(stream_specs.imag, frames.imag)
    torch.testing.assert_close(wav, input_audio)
    torch.testing.assert_close(wav, merged)