Source code for espnet2.spk.pooling.chn_attn_stat_pooling

import torch
import torch.nn as nn

from espnet2.spk.pooling.abs_pooling import AbsPooling


class ChnAttnStatPooling(AbsPooling):
    """Aggregates frame-level features into a single utterance-level feature.

    Proposed in B. Desplanques et al., "ECAPA-TDNN: Emphasized Channel
    Attention, Propagation and Aggregation in TDNN Based Speaker Verification".

    Args:
        input_size: Dimensionality of the input frame-level embeddings,
            determined by the encoder hyperparameters. The output of this
            pooling layer has dimensionality 2 * input_size.
    """

    def __init__(self, input_size: int = 1536):
        super().__init__()
        # Channel- and context-dependent attention: the input is the
        # frame-level features concatenated with their global mean and std
        # (hence input_size * 3 channels); the output is per-channel,
        # per-frame attention weights that sum to 1 over the time axis.
        self.attention = nn.Sequential(
            nn.Conv1d(input_size * 3, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, input_size, kernel_size=1),
            nn.Softmax(dim=2),
        )
    def forward(self, x):
        # x: frame-level features of shape (batch, input_size, time).
        t = x.size()[-1]
        # Append each channel's utterance-level mean and std to every frame
        # so the attention network can use global context.
        global_x = torch.cat(
            (
                x,
                torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                torch.sqrt(
                    torch.var(x, dim=2, keepdim=True).clamp(min=1e-4, max=1e4)
                ).repeat(1, 1, t),
            ),
            dim=1,
        )
        # Per-channel, per-frame attention weights (softmax over time).
        w = self.attention(global_x)

        # Attention-weighted mean and standard deviation over time.
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
        )

        # Concatenated weighted statistics: (batch, 2 * input_size).
        x = torch.cat((mu, sg), dim=1)

        return x
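
A minimal usage sketch (not part of the ESPnet source), assuming an encoder that emits frame-level features of shape (batch, input_size, time); the batch size, time length, and variable names below are illustrative only:

import torch

from espnet2.spk.pooling.chn_attn_stat_pooling import ChnAttnStatPooling

pooling = ChnAttnStatPooling(input_size=1536)
frames = torch.randn(4, 1536, 200)  # (batch, channels, frames), hypothetical sizes
utt_embedding = pooling(frames)
print(utt_embedding.shape)          # torch.Size([4, 3072]), i.e. 2 * input_size

The doubled output dimensionality comes from concatenating the attention-weighted mean and standard deviation, so any projection layer that follows this pooling should expect 2 * input_size features.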