from abc import ABC, abstractmethod
import torch
# Machine epsilon of torch's default floating-point dtype, evaluated once at
# import time (a later torch.set_default_dtype call does not update it).
# Presumably used by loss implementations as a numerical-stability floor —
# NOTE(review): no use is visible in this chunk; confirm callers.
EPS = torch.finfo(torch.get_default_dtype()).eps
class AbsASVSpoofLoss(torch.nn.Module, ABC):
    """Base class for all ASV Spoofing loss modules.

    Concrete subclasses must implement :meth:`forward` and :meth:`score`
    (both are abstract, so the class cannot be instantiated otherwise), and
    are expected to override :attr:`name`.
    """

    # the name will be the key that appears in the reporter
    @property
    def name(self) -> str:
        """Key under which this loss is reported.

        Raises:
            NotImplementedError: if a subclass does not override this property.
        """
        # Raise instead of *returning* the exception class: returning it
        # silently produced a bogus (non-str) reporter key.
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        ref,
        inf,
    ) -> torch.Tensor:
        """Compute the loss for a batch.

        Args:
            ref: presumably the reference/target values — confirm in subclasses.
            inf: presumably the model's inferred outputs — confirm in subclasses.

        Returns:
            torch.Tensor: the loss; per the original contract, of shape (batch,).

        Raises:
            NotImplementedError: always, in this base class.
        """
        # the return tensor should be shape of (batch)
        raise NotImplementedError

    @abstractmethod
    def score(
        self,
        pred,
    ) -> torch.Tensor:
        """Compute scores from ``pred``.

        Args:
            pred: presumably the model's predictions — confirm in subclasses.

        Returns:
            torch.Tensor: the scores.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # ``raise NotImplemented`` would itself fail with a TypeError.
        raise NotImplementedError