# Source code for espnet2.train.gan_trainer

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Trainer module for GAN-based training."""

import argparse
import dataclasses
import logging
import time
from contextlib import contextmanager
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import torch
from packaging.version import parse as V
from typeguard import check_argument_types

from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler, AbsScheduler
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.recursive_op import recursive_average
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.reporter import SubReporter
from espnet2.train.trainer import Trainer, TrainerOptions
from espnet2.utils.build_dataclass import build_dataclass
from espnet2.utils.types import str2bool

if torch.distributed.is_available():
    from torch.distributed import ReduceOp

if V(torch.__version__) < V("1.6.0"):
    # torch<1.6.0 has no automatic mixed precision: expose a do-nothing
    # ``autocast`` context manager and disable the gradient scaler so the
    # trainer code can use both names unconditionally.
    GradScaler = None

    @contextmanager
    def autocast(enabled=True):  # NOQA
        yield

else:
    from torch.cuda.amp import GradScaler, autocast

try:
    import fairscale
except ImportError:
    fairscale = None


@dataclasses.dataclass
class GANTrainerOptions(TrainerOptions):
    """Trainer option dataclass for GANTrainer."""

    # If True, each mini-batch runs the generator turn before the
    # discriminator turn; otherwise the discriminator goes first.
    generator_first: bool
class GANTrainer(Trainer):
    """Trainer for GAN-based training.

    If you'd like to use this trainer, the model must inherit
    espnet.train.abs_gan_espnet_model.AbsGANESPnetModel.

    """

    @classmethod
    def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval(), and plot_attention()."""
        assert check_argument_types()
        return build_dataclass(GANTrainerOptions, args)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Add additional arguments for GAN-trainer."""
        parser.add_argument(
            "--generator_first",
            type=str2bool,
            default=False,
            help="Whether to update generator first.",
        )

    @staticmethod
    def _parse_optim_idx(optim_idx) -> Optional[int]:
        """Normalize the ``optim_idx`` entry of the model output.

        Args:
            optim_idx: None, an int, or a 0-dim / 1-dim ``torch.Tensor``.
                A 1-dim tensor (presumably per-device copies gathered when
                running on several GPUs) must hold the same value in every
                entry.

        Returns:
            None when ``optim_idx`` is None, otherwise the index as a Python
            scalar.

        Raises:
            RuntimeError: If ``optim_idx`` has an unsupported type or rank,
                is an empty tensor, or is a 1-dim tensor with differing
                entries.

        """
        if optim_idx is None or isinstance(optim_idx, int):
            return optim_idx
        if not isinstance(optim_idx, torch.Tensor):
            raise RuntimeError(
                "optim_idx must be int or 1dim torch.Tensor, "
                f"but got {type(optim_idx)}"
            )
        if optim_idx.dim() >= 2:
            raise RuntimeError(
                "optim_idx must be int or 1dim torch.Tensor, "
                f"but got {optim_idx.dim()}dim tensor"
            )
        if optim_idx.dim() == 0:
            return optim_idx.item()
        if optim_idx.numel() == 0:
            raise RuntimeError("optim_idx must be a non-empty 1dim tensor")
        if bool((optim_idx != optim_idx[0]).any()):
            raise RuntimeError(
                "optim_idx must be 1dim tensor "
                "having same values for all entries"
            )
        return optim_idx[0].item()

    @classmethod
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer,
        options: GANTrainerOptions,
        distributed_option: DistributedOption,
    ) -> bool:
        """Train one epoch.

        For every mini-batch the model is run once per turn ("generator" and
        "discriminator", ordered by ``options.generator_first``). Each turn
        does its own forward, backward and optimizer step. Only the optimizer
        selected by the ``optim_idx`` returned by the model is stepped, or
        every optimizer if the model returns no ``optim_idx``.

        Returns:
            True if no step in the epoch was valid, i.e. every update was
            skipped because of a non-finite gradient norm, or the iterator
            yielded nothing. Otherwise False.

        Raises:
            NotImplementedError: If ``accum_grad > 1`` or ``grad_noise`` is set.
            RuntimeError: If the model output is not a dict or its
                ``optim_idx`` is malformed.

        """
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        generator_first = options.generator_first
        distributed = distributed_option.distributed

        # Check unavailable options
        # TODO(kan-bayashi): Support the use of these options
        if accum_grad > 1:
            raise NotImplementedError(
                "accum_grad > 1 is not supported in GAN-based training."
            )
        if grad_noise:
            raise NotImplementedError(
                "grad_noise is not supported in GAN-based training."
            )

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                # The iterator has no len(); fall back to a fixed interval.
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            turn_start_time = time.perf_counter()
            if generator_first:
                turns = ["generator", "discriminator"]
            else:
                turns = ["discriminator", "generator"]
            for turn in turns:
                with autocast(scaler is not None):
                    with reporter.measure_time(f"{turn}_forward_time"):
                        retval = model(forward_generator=turn == "generator", **batch)

                        # Note(kamo):
                        # Only a dict is supported as the returned value from
                        # the model; tuple/list outputs are rejected.
                        if not isinstance(retval, dict):
                            raise RuntimeError("model output must be dict.")
                        loss = retval["loss"]
                        stats = retval["stats"]
                        weight = retval["weight"]
                        optim_idx = cls._parse_optim_idx(retval.get("optim_idx"))

                        stats = {k: v for k, v in stats.items() if v is not None}
                        if ngpu > 1 or distributed:
                            # Apply weighted averaging for loss and stats
                            loss = (loss * weight.type(loss.dtype)).sum()

                            # if distributed, this method can also apply all_reduce()
                            stats, weight = recursive_average(stats, weight, distributed)

                            # Now weight is summation over all workers
                            loss /= weight

                        if distributed:
                            # NOTE(kamo): Multiply world_size since DistributedDataParallel
                            # automatically normalizes the gradient by world_size.
                            loss *= torch.distributed.get_world_size()

                reporter.register(stats, weight)

                with reporter.measure_time(f"{turn}_backward_time"):
                    if scaler is not None:
                        # Scales loss.  Calls backward() on scaled loss
                        # to create scaled gradients.
                        # Backward passes under autocast are not recommended.
                        # Backward ops run in the same dtype autocast chose
                        # for corresponding forward ops.
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # TODO(kan-bayashi): Compute grad norm without clipping
                grad_norm = None
                if grad_clip > 0.0:
                    # compute the gradient norm to check if it is normal or not
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=grad_clip,
                        norm_type=grad_clip_type,
                    )
                    # PyTorch<=1.4, clip_grad_norm_ returns float value
                    if not isinstance(grad_norm, torch.Tensor):
                        grad_norm = torch.tensor(grad_norm)

                if grad_norm is None or torch.isfinite(grad_norm):
                    all_steps_are_invalid = False
                    with reporter.measure_time(f"{turn}_optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                        if scaler is not None:
                            # Updates the scale for next iteration. Must be
                            # called once, after scaler.step() for every
                            # optimizer used in this iteration: calling it
                            # between steps would reset the per-optimizer
                            # unscale state and double-unscale later ones.
                            scaler.update()
                else:
                    logging.warning(
                        f"The grad norm is {grad_norm}. "
                        "Skipping updating the model."
                    )
                    # Must invoke scaler.update() if unscale_() is used in the
                    # iteration to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                        scaler.update()

                for optimizer in optimizers:
                    # NOTE(kan-bayashi): In the case of GAN, we need to clear
                    #   the gradient of both optimizers after every update.
                    optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch.
                # Only optimizers active in this turn are reported (all of
                # them when the model returned no optim_idx).
                lr_stats = {}
                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
                    for i, pg in enumerate(optimizer.param_groups):
                        if "lr" in pg:
                            lr_stats[f"optim{iopt}_lr{i}"] = pg["lr"]
                reporter.register(lr_stats)
                reporter.register(
                    {f"{turn}_train_time": time.perf_counter() - turn_start_time}
                )
                turn_start_time = time.perf_counter()

            reporter.register({"train_time": time.perf_counter() - start_time})
            start_time = time.perf_counter()

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                logging.info(reporter.log_message(-log_interval))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()

        else:
            # The iterator finished normally: tell the other workers to stop.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid

    @classmethod
    @torch.no_grad()
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        reporter: SubReporter,
        options: GANTrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        """Validate one epoch.

        Runs the model without gradients once per turn ("generator" and
        "discriminator", ordered by ``options.generator_first``) for each
        mini-batch, and registers the returned stats in ``reporter``.
        ``iterator`` yields ``(keys, batch)`` pairs; only ``batch`` is used.

        """
        assert check_argument_types()
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed
        generator_first = options.generator_first

        model.eval()

        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        for _, batch in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            if generator_first:
                turns = ["generator", "discriminator"]
            else:
                turns = ["discriminator", "generator"]
            for turn in turns:
                retval = model(forward_generator=turn == "generator", **batch)
                if isinstance(retval, dict):
                    stats = retval["stats"]
                    weight = retval["weight"]
                else:
                    _, stats, weight = retval
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for stats.
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)

                reporter.register(stats, weight)

            reporter.next()

        else:
            # The iterator finished normally: tell the other workers to stop.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)