"""DNN beamformer module."""
import logging
from typing import List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from torch.nn import functional as F
from torch_complex.tensor import ComplexTensor
import espnet2.enh.layers.beamformer as bf_v1
import espnet2.enh.layers.beamformer_th as bf_v2
from espnet2.enh.layers.complex_utils import stack, to_double, to_float
from espnet2.enh.layers.mask_estimator import MaskEstimator
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
is_torch_1_12_1_plus = V(torch.__version__) >= V("1.12.1")
BEAMFORMER_TYPES = (
# Minimum Variance Distortionless Response beamformer
"mvdr", # RTF-based formula
"mvdr_souden", # Souden's solution
# Minimum Power Distortionless Response beamformer
"mpdr", # RTF-based formula
"mpdr_souden", # Souden's solution
# weighted MPDR beamformer
"wmpdr", # RTF-based formula
"wmpdr_souden", # Souden's solution
    # Weighted Power minimization Distortionless response (WPD) beamformer
"wpd", # RTF-based formula
"wpd_souden", # Souden's solution
# Multi-channel Wiener Filter (MWF) and weighted MWF
"mwf",
"wmwf",
# Speech Distortion Weighted (SDW) MWF
"sdw_mwf",
# Rank-1 MWF
"r1mwf",
# Linearly Constrained Minimum Variance beamformer
"lcmv",
# Linearly Constrained Minimum Power beamformer
"lcmp",
# weighted Linearly Constrained Minimum Power beamformer
"wlcmp",
# Generalized Eigenvalue beamformer
"gev",
"gev_ban", # with blind analytic normalization (BAN) post-filtering
# time-frequency-bin-wise switching (TFS) MVDR beamformer
"mvdr_tfs",
"mvdr_tfs_souden",
)
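
# The "_souden" suffix selects Souden's solution; the same name without the
# suffix selects the RTF-based formula. Illustrative check (not part of the
# original module):
#
#     assert "wpd_souden" in BEAMFORMER_TYPES
#     assert "mvdr_rtf" not in BEAMFORMER_TYPES  # "mvdr_rtf" is a made-up name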


class DNN_Beamformer(torch.nn.Module):
"""DNN mask based Beamformer.
Citation:
Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
http://proceedings.mlr.press/v70/ochiai17a/ochiai17a.pdf
"""
def __init__(
self,
bidim,
btype: str = "blstmp",
blayers: int = 3,
bunits: int = 300,
bprojs: int = 320,
num_spk: int = 1,
use_noise_mask: bool = True,
nonlinear: str = "sigmoid",
dropout_rate: float = 0.0,
badim: int = 320,
ref_channel: int = -1,
beamformer_type: str = "mvdr_souden",
rtf_iterations: int = 2,
mwf_mu: float = 1.0,
eps: float = 1e-6,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
mask_flooring: bool = False,
flooring_thres: float = 1e-6,
use_torch_solver: bool = True,
# False to use old APIs; True to use torchaudio-based new APIs
use_torchaudio_api: bool = False,
# only for WPD beamformer
btaps: int = 5,
bdelay: int = 3,
):
super().__init__()
bnmask = num_spk + 1 if use_noise_mask else num_spk
self.mask = MaskEstimator(
btype,
bidim,
blayers,
bunits,
bprojs,
dropout_rate,
nmask=bnmask,
nonlinear=nonlinear,
)
self.ref = (
AttentionReference(bidim, badim, eps=eps) if ref_channel < 0 else None
)
self.ref_channel = ref_channel
self.use_noise_mask = use_noise_mask
assert num_spk >= 1, num_spk
self.num_spk = num_spk
self.nmask = bnmask
if beamformer_type not in BEAMFORMER_TYPES:
raise ValueError("Not supporting beamformer_type=%s" % beamformer_type)
if (
beamformer_type == "mvdr_souden" or not beamformer_type.endswith("_souden")
) and not use_noise_mask:
if num_spk == 1:
logging.warning(
"Initializing %s beamformer without noise mask "
"estimator (single-speaker case)" % beamformer_type.upper()
)
logging.warning(
"(1 - speech_mask) will be used for estimating noise "
"PSD in %s beamformer!" % beamformer_type.upper()
)
else:
logging.warning(
"Initializing %s beamformer without noise mask "
"estimator (multi-speaker case)" % beamformer_type.upper()
)
logging.warning(
"Interference speech masks will be used for estimating "
"noise PSD in %s beamformer!" % beamformer_type.upper()
)
self.beamformer_type = beamformer_type
if not beamformer_type.endswith("_souden"):
assert rtf_iterations >= 2, rtf_iterations
# number of iterations in power method for estimating the RTF
self.rtf_iterations = rtf_iterations
# noise suppression weight in SDW-MWF
self.mwf_mu = mwf_mu
assert btaps >= 0 and bdelay >= 0, (btaps, bdelay)
self.btaps = btaps
self.bdelay = bdelay if self.btaps > 0 else 1
self.eps = eps
self.diagonal_loading = diagonal_loading
self.diag_eps = diag_eps
self.mask_flooring = mask_flooring
self.flooring_thres = flooring_thres
self.use_torch_solver = use_torch_solver
if not use_torch_solver:
logging.warning(
"The `use_torch_solver` argument has been deprecated. "
"Now it will always be true in DNN_Beamformer"
)
if use_torchaudio_api and is_torch_1_12_1_plus:
self.bf_func = bf_v2
else:
self.bf_func = bf_v1
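
    # Construction sketch (illustrative values, not part of the original code):
    #
    #     beamformer = DNN_Beamformer(
    #         bidim=257,                    # e.g. the number of STFT frequency bins
    #         num_spk=2,
    #         beamformer_type="mvdr_souden",
    #     )
    #
    # With the default ``ref_channel=-1`` an AttentionReference module is built
    # to estimate the reference microphone; ``ref_channel >= 0`` fixes it instead.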

    def forward(
self,
data: Union[torch.Tensor, ComplexTensor],
ilens: torch.LongTensor,
powers: Optional[List[torch.Tensor]] = None,
oracle_masks: Optional[List[torch.Tensor]] = None,
) -> Tuple[Union[torch.Tensor, ComplexTensor], torch.LongTensor, torch.Tensor]:
"""DNN_Beamformer forward function.
Notation:
B: Batch
C: Channel
T: Time or Sequence length
F: Freq
Args:
data (torch.complex64/ComplexTensor): (B, T, C, F)
ilens (torch.Tensor): (B,)
powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
oracle_masks (List[torch.Tensor] or None): oracle masks (B, F, C, T)
if not None, oracle_masks will be used instead of self.mask
Returns:
enhanced (torch.complex64/ComplexTensor): (B, T, F)
ilens (torch.Tensor): (B,)
masks (torch.Tensor): (B, T, C, F)
"""
# data (B, T, C, F) -> (B, F, C, T)
data = data.permute(0, 3, 2, 1)
data_d = to_double(data)
# mask: [(B, F, C, T)]
if oracle_masks is not None:
masks = oracle_masks
else:
masks, _ = self.mask(data, ilens)
assert self.nmask == len(masks), len(masks)
# floor masks to increase numerical stability
if self.mask_flooring:
masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]
if self.num_spk == 1: # single-speaker case
if self.use_noise_mask:
# (mask_speech, mask_noise)
mask_speech, mask_noise = masks
else:
# (mask_speech,)
mask_speech = masks[0]
mask_noise = 1 - mask_speech
if self.beamformer_type in ("lcmv", "lcmp", "wlcmp"):
raise NotImplementedError("Single source is not supported yet")
beamformer_stats = self.bf_func.prepare_beamformer_stats(
data_d,
[mask_speech],
mask_noise,
powers=powers,
beamformer_type=self.beamformer_type,
bdelay=self.bdelay,
btaps=self.btaps,
eps=self.eps,
)
if self.beamformer_type in ("mvdr", "mpdr", "wmpdr", "wpd"):
enhanced, ws = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"],
psd_distortion=beamformer_stats["psd_distortion"],
)
elif (
self.beamformer_type.endswith("_souden")
or self.beamformer_type == "mwf"
or self.beamformer_type == "wmwf"
or self.beamformer_type == "sdw_mwf"
or self.beamformer_type == "r1mwf"
or self.beamformer_type.startswith("gev")
):
enhanced, ws = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"],
)
else:
                raise ValueError(
                    "Unsupported beamformer_type={}".format(self.beamformer_type)
                )
# (..., F, T) -> (..., T, F)
enhanced = enhanced.transpose(-1, -2)
else: # multi-speaker case
if self.use_noise_mask:
# (mask_speech1, ..., mask_noise)
mask_speech = list(masks[:-1])
mask_noise = masks[-1]
else:
# (mask_speech1, ..., mask_speechX)
mask_speech = list(masks)
mask_noise = None
beamformer_stats = self.bf_func.prepare_beamformer_stats(
data_d,
mask_speech,
mask_noise,
powers=powers,
beamformer_type=self.beamformer_type,
bdelay=self.bdelay,
btaps=self.btaps,
eps=self.eps,
)
if self.beamformer_type in ("lcmv", "lcmp", "wlcmp"):
rtf_mat = self.bf_func.get_rtf_matrix(
beamformer_stats["psd_speech"],
beamformer_stats["psd_distortion"],
diagonal_loading=self.diagonal_loading,
ref_channel=self.ref_channel,
rtf_iterations=self.rtf_iterations,
diag_eps=self.diag_eps,
)
enhanced, ws = [], []
for i in range(self.num_spk):
# treat all other speakers' psd_speech as noises
if self.beamformer_type in ("mvdr", "mvdr_tfs", "wmpdr", "wpd"):
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"][i],
beamformer_stats["psd_speech"][i],
psd_distortion=beamformer_stats["psd_distortion"][i],
)
elif self.beamformer_type in (
"mvdr_souden",
"mvdr_tfs_souden",
"wmpdr_souden",
"wpd_souden",
"wmwf",
"sdw_mwf",
"r1mwf",
"gev",
"gev_ban",
):
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"][i],
beamformer_stats["psd_speech"][i],
)
elif self.beamformer_type == "mpdr":
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"][i],
psd_distortion=beamformer_stats["psd_distortion"][i],
)
elif self.beamformer_type in ("mpdr_souden", "mwf"):
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"][i],
)
elif self.beamformer_type == "lcmp":
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"][i],
rtf_mat=rtf_mat,
spk=i,
)
elif self.beamformer_type in ("lcmv", "wlcmp"):
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"][i],
beamformer_stats["psd_speech"][i],
rtf_mat=rtf_mat,
spk=i,
)
else:
                    raise ValueError(
                        "Unsupported beamformer_type={}".format(self.beamformer_type)
                    )
# (..., F, T) -> (..., T, F)
enh = enh.transpose(-1, -2)
enhanced.append(enh)
ws.append(w)
# (..., F, C, T) -> (..., T, C, F)
masks = [m.transpose(-1, -3) for m in masks]
return enhanced, ilens, masks

    def apply_beamforming(
self,
data,
ilens,
psd_n,
psd_speech,
psd_distortion=None,
rtf_mat=None,
spk=0,
):
"""Beamforming with the provided statistics.
Args:
data (torch.complex64/ComplexTensor): (B, F, C, T)
ilens (torch.Tensor): (B,)
psd_n (torch.complex64/ComplexTensor):
Noise covariance matrix for MVDR (B, F, C, C)
Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
psd_speech (torch.complex64/ComplexTensor):
Speech covariance matrix (B, F, C, C)
psd_distortion (torch.complex64/ComplexTensor):
Noise covariance matrix (B, F, C, C)
rtf_mat (torch.complex64/ComplexTensor):
RTF matrix (B, F, C, num_spk)
spk (int): speaker index
Return:
enhanced (torch.complex64/ComplexTensor): (B, F, T)
ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
"""
# u: (B, C)
if self.ref_channel < 0:
u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
u = u.double()
else:
if self.beamformer_type.endswith("_souden"):
# (optional) Create onehot vector for fixed reference microphone
u = torch.zeros(
*(data.size()[:-3] + (data.size(-2),)),
device=data.device,
dtype=torch.double
)
u[..., self.ref_channel].fill_(1)
else:
# for simplifying computation in RTF-based beamforming
u = self.ref_channel
if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
ws = self.bf_func.get_mvdr_vector_with_rtf(
to_double(psd_n),
to_double(psd_speech),
to_double(psd_distortion),
iterations=self.rtf_iterations,
reference_vector=u,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type == "mvdr_tfs":
assert isinstance(psd_n, (list, tuple))
ws = [
self.bf_func.get_mvdr_vector_with_rtf(
to_double(psd_n_i),
to_double(psd_speech),
to_double(psd_distortion),
iterations=self.rtf_iterations,
reference_vector=u,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
for psd_n_i in psd_n
]
enhanced = stack(
[self.bf_func.apply_beamforming_vector(w, to_double(data)) for w in ws]
)
with torch.no_grad():
                index = enhanced.abs().argmin(dim=0, keepdim=True)
enhanced = enhanced.gather(0, index).squeeze(0)
ws = stack(ws, dim=0)
elif self.beamformer_type in (
"mpdr_souden",
"mvdr_souden",
"wmpdr_souden",
):
ws = self.bf_func.get_mvdr_vector(
to_double(psd_speech),
to_double(psd_n),
u,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type == "mvdr_tfs_souden":
assert isinstance(psd_n, (list, tuple))
ws = [
self.bf_func.get_mvdr_vector(
to_double(psd_speech),
to_double(psd_n_i),
u,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
for psd_n_i in psd_n
]
enhanced = stack(
[self.bf_func.apply_beamforming_vector(w, to_double(data)) for w in ws]
)
with torch.no_grad():
                index = enhanced.abs().argmin(dim=0, keepdim=True)
enhanced = enhanced.gather(0, index).squeeze(0)
ws = stack(ws, dim=0)
elif self.beamformer_type == "wpd":
ws = self.bf_func.get_WPD_filter_with_rtf(
to_double(psd_n),
to_double(psd_speech),
to_double(psd_distortion),
iterations=self.rtf_iterations,
reference_vector=u,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.perform_WPD_filtering(
ws, to_double(data), self.bdelay, self.btaps
)
elif self.beamformer_type == "wpd_souden":
ws = self.bf_func.get_WPD_filter_v2(
to_double(psd_speech),
to_double(psd_n),
u,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.perform_WPD_filtering(
ws, to_double(data), self.bdelay, self.btaps
)
elif self.beamformer_type in ("mwf", "wmwf"):
ws = self.bf_func.get_mwf_vector(
to_double(psd_speech),
to_double(psd_n),
u,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type == "sdw_mwf":
ws = self.bf_func.get_sdw_mwf_vector(
to_double(psd_speech),
to_double(psd_n),
u,
denoising_weight=self.mwf_mu,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type == "r1mwf":
ws = self.bf_func.get_rank1_mwf_vector(
to_double(psd_speech),
to_double(psd_n),
u,
denoising_weight=self.mwf_mu,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type in ("lcmp", "wlcmp", "lcmv"):
ws = self.bf_func.get_lcmv_vector_with_rtf(
to_double(psd_n),
to_double(rtf_mat),
reference_vector=spk,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type.startswith("gev"):
ws = self.bf_func.get_gev_vector(
to_double(psd_n),
to_double(psd_speech),
mode="power",
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = self.bf_func.apply_beamforming_vector(ws, to_double(data))
if self.beamformer_type == "gev_ban":
gain = self.bf_func.blind_analytic_normalization(ws, to_double(psd_n))
enhanced = enhanced * gain.unsqueeze(-1)
else:
            raise ValueError(
                "Unsupported beamformer_type={}".format(self.beamformer_type)
            )
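        # Note (added for clarity): the statistics above are processed in double
        # precision for numerical stability; the outputs are cast back to the
        # input dtype on return.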
return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

    def predict_mask(
self, data: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
"""Predict masks for beamforming.
Args:
data (torch.complex64/ComplexTensor): (B, T, C, F), double precision
ilens (torch.Tensor): (B,)
Returns:
masks (torch.Tensor): (B, T, C, F)
ilens (torch.Tensor): (B,)
"""
masks, _ = self.mask(to_float(data.permute(0, 3, 2, 1)), ilens)
# (B, F, C, T) -> (B, T, C, F)
masks = [m.transpose(-1, -3) for m in masks]
return masks, ilens
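
    # Usage sketch (illustrative): predict masks without applying beamforming.
    #
    #     masks, ilens = beamformer.predict_mask(spec, ilens)
    #     # len(masks) == num_spk + 1 if use_noise_mask else num_spk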


class AttentionReference(torch.nn.Module):
    """Attention-based reference channel estimator.

    Estimates a soft reference-microphone vector ``u`` from the speech PSD
    matrix, as used by DNN_Beamformer when ``ref_channel < 0``.
    """

    def __init__(self, bidim, att_dim, eps=1e-6):
super().__init__()
self.mlp_psd = torch.nn.Linear(bidim, att_dim)
self.gvec = torch.nn.Linear(att_dim, 1)
self.eps = eps

    def forward(
self,
psd_in: Union[torch.Tensor, ComplexTensor],
ilens: torch.LongTensor,
scaling: float = 2.0,
) -> Tuple[torch.Tensor, torch.LongTensor]:
"""Attention-based reference forward function.
Args:
psd_in (torch.complex64/ComplexTensor): (B, F, C, C)
ilens (torch.Tensor): (B,)
scaling (float):
Returns:
u (torch.Tensor): (B, C)
ilens (torch.Tensor): (B,)
"""
B, _, C = psd_in.size()[:3]
assert psd_in.size(2) == psd_in.size(3), psd_in.size()
# psd_in: (B, F, C, C)
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
        )
# psd: (B, F, C, C) -> (B, C, F)
psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
# Calculate amplitude
psd_feat = (psd.real**2 + psd.imag**2 + self.eps) ** 0.5
# (B, C, F) -> (B, C, F2)
mlp_psd = self.mlp_psd(psd_feat)
# (B, C, F2) -> (B, C, 1) -> (B, C)
e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
u = F.softmax(scaling * e, dim=-1)
return u, ilens
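

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the original
    # module); assumes the espnet2 dependencies imported above are available.
    beamformer = DNN_Beamformer(bidim=257, num_spk=1)
    spec = torch.randn(2, 50, 4, 257, dtype=torch.complex64)  # (B, T, C, F)
    ilens = torch.full((2,), 50, dtype=torch.long)
    enhanced, ilens, masks = beamformer(spec, ilens)
    print(enhanced.shape)  # expected: torch.Size([2, 50, 257])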