# Source code for espnet2.enh.loss.wrappers.abs_wrapper

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

import torch


class AbsLossWrapper(torch.nn.Module, ABC):
    """Abstract base for every speech-enhancement loss wrapper module.

    Concrete wrappers subclass this and must implement :meth:`forward`;
    the class itself cannot be instantiated because ``forward`` is abstract.

    Attributes:
        weight: Coefficient for this wrapper's loss in multi-task learning.
            The overall training target is combined as
            ``loss = weight_1 * loss_1 + ... + weight_N * loss_N``.
            It is a class-level default of 1.0, which subclasses or
            instances may override.
    """

    # Class-level default coefficient used in the weighted multi-task sum.
    weight = 1.0

    @abstractmethod
    def forward(
        self, ref: List, inf: List, others: Dict
    ) -> Tuple[torch.Tensor, Dict, Dict]:
        """Compute this wrapper's loss.

        Args:
            ref: List of reference items (presumably ground-truth signals;
                TODO confirm against concrete wrappers).
            inf: List of inferred items (presumably the model's estimates;
                TODO confirm against concrete wrappers).
            others: Dict of any additional inputs the wrapper needs.

        Returns:
            A 3-tuple ``(torch.Tensor, Dict, Dict)``. This is presumably
            ``(loss, stats, others)``; verify against concrete subclasses.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError