import torch
from espnet2.diar.decoder.abs_decoder import AbsDecoder
class LinearDecoder(AbsDecoder):
    """Linear decoder for speaker diarization.

    Maps each frame of the encoder output to ``num_spk`` output values
    (one per speaker) using a single ``torch.nn.Linear`` layer.
    """

    def __init__(
        self,
        encoder_output_size: int,
        num_spk: int = 2,
    ):
        """Build the projection layer.

        Args:
            encoder_output_size (int): feature size F of the encoder output.
            num_spk (int): number of speakers, i.e. the output feature size.
        """
        super().__init__()
        self._num_spk = num_spk
        # The attribute name determines the parameter keys in ``state_dict``
        # (``linear_decoder.weight`` / ``linear_decoder.bias``); keep it stable
        # so existing checkpoints still load.
        self.linear_decoder = torch.nn.Linear(
            in_features=encoder_output_size,
            out_features=num_spk,
        )

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): hidden_space [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            torch.Tensor: ``input`` projected along its last dimension,
                i.e. [Batch, T, num_spk] for a [Batch, T, F] input.
        """
        # ``ilens`` is accepted but not read: every frame of ``input`` is
        # projected regardless of length. Presumably kept only to match the
        # AbsDecoder forward signature — TODO confirm.
        projected = self.linear_decoder(input)
        return projected

    @property
    def num_spk(self) -> int:
        """Number of speakers (the decoder's output feature size)."""
        return self._num_spk