Source code for espnet2.spk.projector.rawnet3_projector

import torch

from espnet2.spk.projector.abs_projector import AbsProjector


class RawNet3Projector(AbsProjector):
    """Projector that batch-normalizes its input, then applies a linear layer.

    The forward pass computes ``fc(bn(x))``:

    * ``bn`` is a ``torch.nn.BatchNorm1d`` over ``input_size`` features.
    * ``fc`` is a ``torch.nn.Linear`` mapping ``input_size`` to ``output_size``.

    Args:
        input_size: Number of features fed to the BatchNorm1d layer; this is
            also the input dimension of the linear layer.
        output_size: Output dimension of the linear layer, reported by
            :meth:`output_size`.
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        # Build the normalization layer before the linear layer so that
        # submodule registration order (and parameter initialization order)
        # is bn -> fc. The attribute names ``bn`` / ``fc`` form the
        # state_dict keys, so they must not be renamed.
        norm_layer = torch.nn.BatchNorm1d(input_size)
        linear_layer = torch.nn.Linear(input_size, output_size)
        self.bn = norm_layer
        self.fc = linear_layer
        self._output_size = output_size

    def output_size(self) -> int:
        """Return the output dimension given at construction time."""
        return self._output_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` by batch-normalizing it and applying the linear layer.

        Args:
            x: Input tensor. Presumably a pooled embedding of shape
                (batch, input_size); confirm against callers. BatchNorm1d
                also accepts (batch, input_size, length).

        Returns:
            The output of the linear layer applied to the normalized input.
        """
        normalized = self.bn(x)
        return self.fc(normalized)