# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""ESPnetModel abstract class for GAN-based training."""
from abc import ABC, abstractmethod
from typing import Dict, Union
import torch
from espnet2.train.abs_espnet_model import AbsESPnetModel
class AbsGANESPnetModel(AbsESPnetModel, torch.nn.Module, ABC):
    """The common abstract class among each GAN-based task.

    "ESPnetModel" refers to a class which inherits torch.nn.Module and
    holds the dnn-models whose "forward" it drives as member fields,
    a.k.a. the delegate pattern. Its "forward" must accept the argument
    "forward_generator" and return a dict of "loss", "stats", "weight",
    and "optim_idx". "optim_idx" for the generator must be 0 and that for
    the discriminator must be 1.

    Example:
        >>> from espnet2.tasks.abs_task import AbsTask
        >>> class YourESPnetModel(AbsGANESPnetModel):
        ...     def forward(self, input, input_lengths, forward_generator=True):
        ...         ...
        ...         if forward_generator:
        ...             # return loss for the generator
        ...             # optim idx 0 indicates generator optimizer
        ...             return dict(loss=loss, stats=stats, weight=weight, optim_idx=0)
        ...         else:
        ...             # return loss for the discriminator
        ...             # optim idx 1 indicates discriminator optimizer
        ...             return dict(loss=loss, stats=stats, weight=weight, optim_idx=1)
        >>> class YourTask(AbsTask):
        ...     @classmethod
        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:

    """

    @abstractmethod
    def forward(
        self,
        forward_generator: bool = True,
        **batch: torch.Tensor,
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor], int]]:
        """Return the generator loss or the discriminator loss.

        This method must have an argument "forward_generator" to switch
        between the generator loss calculation and the discriminator loss
        calculation. If forward_generator is true, return the generator
        loss with optim_idx 0. If forward_generator is false, return the
        discriminator loss with optim_idx 1.

        Args:
            forward_generator (bool): Whether to return the generator loss
                or the discriminator loss. This must have the default value.
            **batch (torch.Tensor): Batch tensors passed by keyword.

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        raise NotImplementedError

    @abstractmethod
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return a dict of named tensors derived from the given batch.

        NOTE(review): presumably used to gather features for statistics
        collection, mirroring the base ESPnet model API -- confirm against
        the caller; only the signature is visible here.

        Args:
            **batch (torch.Tensor): Batch tensors passed by keyword.

        Returns:
            Dict[str, torch.Tensor]: Mapping from name to tensor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        raise NotImplementedError