Source code for espnet2.enh.layers.wpe

from typing import Tuple, Union

import torch
import torch.nn.functional as F
import torch_complex.functional as FC
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.complex_utils import einsum, matmul, reverse

# True when the installed PyTorch is >= 1.9.0. The native complex-tensor code
# paths in this module (torch.is_complex / torch.complex) are only taken when
# this flag is set.
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


"""WPE PyTorch version: ported from https://github.com/fgnt/nara_wpe
Many functions are not yet sufficiently tested."""


def signal_framing(
    signal: Union[torch.Tensor, ComplexTensor],
    frame_length: int,
    frame_step: int,
    pad_value=0,
) -> Union[torch.Tensor, ComplexTensor]:
    """Expand the last axis of ``signal`` into frames of ``frame_length``.

    The last axis is right-padded with ``frame_length - 1`` copies of
    ``pad_value`` before framing, so frames starting near the end of the
    signal are completed with padding. Complex inputs (``ComplexTensor`` or,
    on torch >= 1.9, native complex tensors) are framed separately on their
    real and imaginary parts; ``pad_value`` is applied to both parts.

    Args:
        signal: (B * F, D, T)
        frame_length: Number of samples per frame (W).
        frame_step: Hop between the first samples of consecutive frames.
        pad_value: Constant used for the right padding.

    Returns:
        torch.Tensor or ComplexTensor: (B * F, D, T', W), where T' is the
        number of frame start positions.
    """
    if isinstance(signal, ComplexTensor):
        real = signal_framing(signal.real, frame_length, frame_step, pad_value)
        imag = signal_framing(signal.imag, frame_length, frame_step, pad_value)
        return ComplexTensor(real, imag)
    elif is_torch_1_9_plus and torch.is_complex(signal):
        real = signal_framing(signal.real, frame_length, frame_step, pad_value)
        imag = signal_framing(signal.imag, frame_length, frame_step, pad_value)
        return torch.complex(real, imag)

    signal = F.pad(signal, (0, frame_length - 1), "constant", pad_value)

    # Build the flat gather index in a single pass. Concatenating per-frame
    # lists with ``sum(..., [])`` copies the accumulated list for every frame
    # and is quadratic in the number of frames.
    frame_starts = range(0, signal.size(-1) - frame_length + 1, frame_step)
    indices = [
        start + offset for start in frame_starts for offset in range(frame_length)
    ]

    # Gather all frame samples, then split the last axis into (frames, W).
    signal = signal[..., indices].view(*signal.size()[:-1], -1, frame_length)
    return signal
def get_power(signal, dim=-2) -> torch.Tensor:
    """Compute the squared magnitude of ``signal`` averaged over ``dim``.

    Args:
        signal: Complex-valued signal exposing ``.real`` and ``.imag``,
            e.g. a single-frequency-batch signal with shape (F, C, T).
        dim: Axis over which the mean is taken (the channel axis of an
            (F, C, T) input by default).

    Returns:
        Power with shape (F, T) for an (F, C, T) input and the default ``dim``.
    """
    squared_magnitude = signal.real.pow(2) + signal.imag.pow(2)
    return squared_magnitude.mean(dim=dim)
def get_correlations(
    Y: Union[torch.Tensor, ComplexTensor], inverse_power: torch.Tensor, taps, delay
) -> Tuple[Union[torch.Tensor, ComplexTensor], Union[torch.Tensor, ComplexTensor]]:
    """Compute the weighted correlations over a window of ``taps`` frames.

    Args:
        Y : Complex-valued STFT signal with shape (F, C, T)
        inverse_power : Weighting factor with shape (F, T)
        taps (int): Length of the correlation window
        delay (int): Delay for the weighting factor

    Returns:
        Correlation matrix of shape (F, taps*C, taps*C)
        Correlation vector of shape (F, taps, C, C)
    """
    assert inverse_power.dim() == 2, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))

    n_freq, n_chan, n_time = Y.size()
    # The weights and the target signal are both sliced from this offset on.
    offset = delay + taps - 1
    n_valid = n_time - offset

    # Y: (F, C, T) -> frames: (F, C, T, taps), keeping the first n_valid frames
    frames = signal_framing(Y, frame_length=taps, frame_step=1)
    # Reverse along the taps axis
    Psi = reverse(frames[..., :n_valid, :], dim=-1)

    weights = inverse_power[..., None, offset:, None]
    Psi_conj_norm = Psi.conj() * weights

    # (F, C, T, taps) x (F, C, T, taps) -> (F, taps, C, taps, C)
    # -> (F, taps * C, taps * C)
    correlation_matrix = einsum("fdtk,fetl->fkdle", Psi_conj_norm, Psi).reshape(
        n_freq, taps * n_chan, taps * n_chan
    )

    # (F, C, T, taps) x (F, C, T) -> (F, taps, C, C)
    correlation_vector = einsum("fdtk,fet->fked", Psi_conj_norm, Y[..., offset:])

    return correlation_matrix, correlation_vector
def get_filter_matrix_conj(
    correlation_matrix: Union[torch.Tensor, ComplexTensor],
    correlation_vector: Union[torch.Tensor, ComplexTensor],
    eps: float = 1e-10,
) -> Union[torch.Tensor, ComplexTensor]:
    """Calculate the (conjugate) filter matrix from the correlations.

    The correlation matrix is regularized by adding ``eps`` to its diagonal
    before inversion. The regularization is applied out of place, so the
    ``correlation_matrix`` passed in by the caller is left unmodified.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps: Value added to the diagonal of ``correlation_matrix``
            before it is inverted.

    Returns:
        filter_matrix_conj (torch.complex/ComplexTensor): (F, taps, C, C)
    """
    n_freq, taps, n_chan, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = (
        correlation_vector.permute(0, 2, 1, 3)
        .contiguous()
        .view(n_freq, n_chan, taps * n_chan)
    )

    # Identity with singleton leading axes so it broadcasts over the batch.
    eye = torch.eye(
        correlation_matrix.size(-1),
        dtype=correlation_matrix.dtype,
        device=correlation_matrix.device,
    )
    shape = (1,) * (correlation_matrix.dim() - 2) + tuple(
        correlation_matrix.shape[-2:]
    )
    eye = eye.view(*shape)

    # Out-of-place add: do not overwrite the caller's correlation matrix.
    regularized = correlation_matrix + eps * eye
    inv_correlation_matrix = regularized.inverse()

    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = matmul(
        correlation_vector, inv_correlation_matrix.transpose(-1, -2)
    )

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = stacked_filter_conj.view(
        n_freq, n_chan, taps, n_chan
    ).permute(0, 2, 3, 1)
    return filter_matrix_conj
def perform_filter_operation(
    Y: Union[torch.Tensor, ComplexTensor],
    filter_matrix_conj: Union[torch.Tensor, ComplexTensor],
    taps,
    delay,
) -> Union[torch.Tensor, ComplexTensor]:
    """Subtract the filtered, delayed copies of ``Y`` from ``Y``.

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter_matrix_conj : Filter matrix (F, taps, C, C)
        taps : Number of delayed copies of ``Y`` that are filtered
        delay : Shift (in frames) of the first delayed copy

    Returns:
        ``Y`` minus the estimated reverberation tail, shape (F, C, T)

    Raises:
        ValueError: If ``Y`` is neither a ``ComplexTensor`` nor (on
            torch >= 1.9) a native complex tensor.
    """
    if isinstance(Y, ComplexTensor):
        complex_module, pad_func = FC, FC.pad
    elif is_torch_1_9_plus and torch.is_complex(Y):
        complex_module, pad_func = torch, F.pad
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )

    n_frames = Y.size(-1)

    # One copy of Y per tap, shifted right by (delay + tap) frames: drop the
    # trailing frames and zero-pad the front so the length stays T.
    delayed_copies = []
    for tap in range(taps):
        shift = delay + tap
        delayed_copies.append(
            pad_func(
                Y[:, :, : n_frames - shift], (shift, 0), mode="constant", value=0
            )
        )

    # Y_tilde: (taps, F, C, T)
    Y_tilde = complex_module.stack(delayed_copies, dim=0)
    reverb_tail = complex_module.einsum(
        "fpde,pfdt->fet", (filter_matrix_conj, Y_tilde)
    )
    return Y - reverb_tail
def wpe_one_iteration(
    Y: Union[torch.Tensor, ComplexTensor],
    power: torch.Tensor,
    taps: int = 10,
    delay: int = 3,
    eps: float = 1e-10,
    inverse_power: bool = True,
) -> Union[torch.Tensor, ComplexTensor]:
    """Run a single WPE iteration.

    Args:
        Y: Complex valued STFT signal with shape (..., C, T)
        power: (..., T)
        taps: Number of filter taps
        delay: Delay as a guard interval, such that X does not become zero.
        eps: Lower bound applied to ``power`` before it is inverted. It is
            not forwarded to ``get_filter_matrix_conj``, which uses its own
            default.
        inverse_power (bool): If True, the weights are ``1 / clamp(power)``;
            otherwise ``power`` is used as the weights directly.

    Returns:
        enhanced: (..., C, T)
    """
    leading_shape = Y.size()[:-2]
    assert leading_shape == power.size()[:-1]

    # Collapse every leading axis into one batch axis: (N, C, T) and (N, T).
    Y = Y.view(-1, *Y.size()[-2:])
    power = power.view(-1, power.size()[-1])

    weights = 1 / torch.clamp(power, min=eps) if inverse_power else power

    correlation_matrix, correlation_vector = get_correlations(Y, weights, taps, delay)
    filter_matrix_conj = get_filter_matrix_conj(correlation_matrix, correlation_vector)
    enhanced = perform_filter_operation(Y, filter_matrix_conj, taps, delay)

    # Restore the original leading axes.
    return enhanced.view(*leading_shape, *Y.size()[-2:])
def wpe(
    Y: Union[torch.Tensor, ComplexTensor], taps=10, delay=3, iterations=3
) -> Union[torch.Tensor, ComplexTensor]:
    """Run WPE for a fixed number of iterations.

    Every iteration filters the original observation ``Y``; only the power
    estimate is taken from the previous iteration's output.

    Args:
        Y: Complex valued STFT signal with shape (F, C, T)
        taps: Number of filter taps
        delay: Delay as a guard interval, such that X does not become zero.
        iterations: Number of WPE iterations; with 0 the input is returned
            unchanged.

    Returns:
        enhanced: (F, C, T)
    """
    estimate = Y
    for _ in range(iterations):
        estimate = wpe_one_iteration(
            Y, get_power(estimate), taps=taps, delay=delay
        )
    return estimate