Source code for espnet2.diar.attractor.rnn_attractor

import torch

from espnet2.diar.attractor.abs_attractor import AbsAttractor


class RnnAttractor(AbsAttractor):
    """Encoder-decoder attractor for speaker diarization.

    An LSTM encoder reads the (padded) frame-level hidden states. Its final
    (hidden, cell) state initializes an LSTM decoder, which runs over
    ``dec_input`` to produce one attractor vector per decoding step. A linear
    layer maps each attractor to a single unnormalized score (``att_prob``).
    """

    def __init__(
        self,
        encoder_output_size: int,
        layer: int = 1,
        unit: int = 512,
        dropout: float = 0.1,
        attractor_grad: bool = True,
    ):
        """Build the attractor encoder/decoder.

        Args:
            encoder_output_size: feature dimension F of both ``enc_input`` and
                ``dec_input`` (input size of both LSTMs).
            layer: number of stacked LSTM layers in the encoder and decoder.
            unit: LSTM hidden size; this is also the attractor dimension.
            dropout: dropout probability. It is applied to the decoder output,
                and between stacked LSTM layers when ``layer > 1``.
            attractor_grad: if falsy, ``att_prob`` is computed from a detached
                copy of the attractors. Gradients from ``att_prob`` then do not
                reach the LSTMs.
        """
        super().__init__()
        # torch.nn.LSTM only applies dropout *between* stacked layers. With a
        # single layer the value has no effect and only triggers a
        # UserWarning, so pass 0.0 in that case. The forward computation is
        # unchanged, and invalid probabilities are still rejected by
        # torch.nn.Dropout below.
        rnn_dropout = dropout if layer > 1 else 0.0
        self.attractor_encoder = torch.nn.LSTM(
            input_size=encoder_output_size,
            hidden_size=unit,
            num_layers=layer,
            dropout=rnn_dropout,
            batch_first=True,
        )
        self.attractor_decoder = torch.nn.LSTM(
            input_size=encoder_output_size,
            hidden_size=unit,
            num_layers=layer,
            dropout=rnn_dropout,
            batch_first=True,
        )
        self.dropout_layer = torch.nn.Dropout(p=dropout)
        self.linear_projection = torch.nn.Linear(unit, 1)
        self.attractor_grad = attractor_grad

    def forward(
        self,
        enc_input: torch.Tensor,
        ilens: torch.Tensor,
        dec_input: torch.Tensor,
    ):
        """Compute attractors and their scores.

        Args:
            enc_input (torch.Tensor): hidden_space [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]
            dec_input (torch.Tensor): decoder input (zeros) [Batch, num_spk + 1, F]

        Returns:
            attractor: [Batch, num_spk + 1, unit] decoder outputs after dropout.
            att_prob: [Batch, num_spk + 1, 1] raw linear scores (no sigmoid
                is applied here).
        """
        # Pack so the encoder ignores padded frames. pack_padded_sequence
        # requires CPU lengths; enforce_sorted=False accepts unsorted batches.
        pack = torch.nn.utils.rnn.pack_padded_sequence(
            enc_input, lengths=ilens.cpu(), batch_first=True, enforce_sorted=False
        )
        # For packed input, the returned (h_n, c_n) hold each sequence's state
        # at its last valid frame. They seed the decoder.
        _, hs = self.attractor_encoder(pack)
        attractor, _ = self.attractor_decoder(dec_input, hs)
        attractor = self.dropout_layer(attractor)

        # Optionally stop gradients from the score branch flowing back into
        # the attractors. Use truthiness rather than `is True`, so values like
        # numpy.bool_(True) or 1 are honored.
        proj_input = attractor if self.attractor_grad else attractor.detach()
        att_prob = self.linear_projection(proj_input)
        return attractor, att_prob