# Source code for espnet2.diar.decoder.linear_decoder

import torch

from espnet2.diar.decoder.abs_decoder import AbsDecoder


class LinearDecoder(AbsDecoder):
    """Linear decoder for speaker diarization.

    Maps the feature dimension of the encoder output to ``num_spk`` values
    with a single ``torch.nn.Linear`` layer.

    Args:
        encoder_output_size: Size of the last (feature) dimension of the
            tensor passed to :meth:`forward`.
        num_spk: Number of speakers, i.e. the output feature size of the
            linear layer.
    """

    def __init__(
        self,
        encoder_output_size: int,
        num_spk: int = 2,
    ):
        super().__init__()
        self._num_spk = num_spk
        self.linear_decoder = torch.nn.Linear(encoder_output_size, num_spk)

    def forward(self, input: torch.Tensor, ilens: torch.Tensor) -> torch.Tensor:
        """Forward.

        Args:
            input (torch.Tensor): hidden_space [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]; accepted as part of
                the call signature but not used by this decoder.

        Returns:
            torch.Tensor: output of the linear layer [Batch, T, num_spk];
                no activation is applied here.
        """
        return self.linear_decoder(input)

    @property
    def num_spk(self) -> int:
        """Number of speakers this decoder was constructed with."""
        return self._num_spk