Source code for espnet2.asvspoof.loss.am_softmax_loss

import torch

from espnet2.asvspoof.loss.abs_loss import AbsASVSpoofLoss
from espnet.nets.pytorch_backend.nets_utils import to_device


class ASVSpoofAMSoftmaxLoss(AbsASVSpoofLoss):
    """Additive-margin softmax (AM-Softmax) binary loss for ASV spoofing."""

    def __init__(
        self,
        weight: float = 1.0,
        enc_dim: int = 128,
        s: float = 20,
        m: float = 0.5,
    ):
        super().__init__()
        self.weight = weight
        self.enc_dim = enc_dim
        self.s = s
        self.m = m
        # Learnable class centers, one per class
        self.centers = torch.nn.Parameter(torch.randn(2, enc_dim))
        self.sigmoid = torch.nn.Sigmoid()
        self.loss = torch.nn.BCELoss(reduction="mean")

    def forward(self, label: torch.Tensor, emb: torch.Tensor, **kwargs):
        """Forward.

        Args:
            label (torch.Tensor): ground truth label [Batch, 1]
            emb (torch.Tensor): encoder embedding output [Batch, T, enc_dim]
        """
        batch_size = emb.shape[0]
        # Pool the frame-level embeddings over time: [Batch, enc_dim]
        emb = torch.mean(emb, dim=1)
        # L2-normalize embeddings and class centers so the logits are cosine similarities
        norms = torch.norm(emb, p=2, dim=-1, keepdim=True)
        nfeat = torch.div(emb, norms)
        norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True)
        ncenters = torch.div(self.centers, norms_c)
        logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1))

        # Subtract the additive margin m from the true-class logit, then scale by s
        y_onehot = torch.zeros(batch_size, 2, device=emb.device)
        y_onehot.scatter_(1, label.view(-1, 1), self.m)
        margin_logits = self.s * (logits - y_onehot)

        loss = self.loss(self.sigmoid(margin_logits[:, 0]), label.view(-1).float())
        return loss

    def score(self, emb: torch.Tensor):
        """Prediction.

        Args:
            emb (torch.Tensor): encoder embedding output [Batch, T, enc_dim]
        """
        # Same pooling and normalization as in forward(), without the margin
        emb = torch.mean(emb, dim=1)
        norms = torch.norm(emb, p=2, dim=-1, keepdim=True)
        nfeat = torch.div(emb, norms)
        norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True)
        ncenters = torch.div(self.centers, norms_c)
        logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1))
        # Cosine similarity to center 0 is used as the detection score
        return logits[:, 0]
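

# Minimal usage sketch (illustration only, not part of the original module).
# It assumes integer 0/1 labels of shape [Batch, 1] and a dummy encoder output
# of shape [Batch, T, enc_dim]; the batch size, sequence length, and random
# tensors below are arbitrary. The margin logits follow the AM-Softmax
# formulation s * (cos_theta - m) applied to the true class.
if __name__ == "__main__":
    loss_fn = ASVSpoofAMSoftmaxLoss(enc_dim=128, s=20, m=0.5)
    emb = torch.randn(4, 50, 128)          # [Batch=4, T=50, enc_dim=128]
    label = torch.randint(0, 2, (4, 1))    # 0/1 labels, shape [Batch, 1]
    loss = loss_fn(label, emb)             # scalar BCE loss over margin logits
    scores = loss_fn.score(emb)            # [Batch] cosine scores w.r.t. center 0
    print(loss.item(), scores.shape)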