Source code for espnet2.tasks.uasr

import argparse
import logging
from typing import Callable, Collection, Dict, List, Optional, Tuple

import numpy as np
import torch
from typeguard import check_argument_types, check_return_type

from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.asr.frontend.fused import FusedFrontends
from espnet2.asr.frontend.s3prl import S3prlFrontend
from espnet2.asr.frontend.windowing import SlidingWindow
from espnet2.tasks.abs_task import AbsTask, optim_classes
from espnet2.torch_utils.initialize import initialize
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import CommonPreprocessor
from espnet2.train.uasr_trainer import UASRTrainer
from espnet2.uasr.discriminator.abs_discriminator import AbsDiscriminator
from espnet2.uasr.discriminator.conv_discriminator import ConvDiscriminator
from espnet2.uasr.espnet_model import ESPnetUASRModel
from espnet2.uasr.generator.abs_generator import AbsGenerator
from espnet2.uasr.generator.conv_generator import ConvGenerator
from espnet2.uasr.loss.abs_loss import AbsUASRLoss
from espnet2.uasr.loss.discriminator_loss import UASRDiscriminatorLoss
from espnet2.uasr.loss.gradient_penalty import UASRGradientPenalty
from espnet2.uasr.loss.phoneme_diversity_loss import UASRPhonemeDiversityLoss
from espnet2.uasr.loss.pseudo_label_loss import UASRPseudoLabelLoss
from espnet2.uasr.loss.smoothness_penalty import UASRSmoothnessPenalty
from espnet2.uasr.segmenter.abs_segmenter import AbsSegmenter
from espnet2.uasr.segmenter.join_segmenter import JoinSegmenter
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import int_or_none, str2bool, str_or_none

# Selectable speech front-ends, exposed on the CLI as --frontend / --frontend_conf.
# Every registered class is type-checked against AbsFrontend; "default" is used
# when the option is not given.
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
        "s3prl": S3prlFrontend,
        "fused": FusedFrontends,
    },
    type_check=AbsFrontend,
    default="default",
)
# Selectable segmenters, exposed on the CLI as --segmenter / --segmenter_conf.
# The segmenter is optional: the default is None (no segmenter).
segmenter_choices = ClassChoices(
    name="segmenter",
    classes={
        "join": JoinSegmenter,
    },
    type_check=AbsSegmenter,
    default=None,
    optional=True,
)
# Selectable discriminators, exposed on the CLI as
# --discriminator / --discriminator_conf. "conf" defaults to the conv model.
discriminator_choices = ClassChoices(
    name="discriminator",
    classes={
        "conv": ConvDiscriminator,
    },
    type_check=AbsDiscriminator,
    default="conv",
)
# Selectable generators, exposed on the CLI as --generator / --generator_conf.
generator_choices = ClassChoices(
    name="generator",
    classes={
        "conv": ConvGenerator,
    },
    type_check=AbsGenerator,
    default="conv",
)
# Registry of UASR loss terms. Besides the --loss / --loss_conf options it
# provides, UASRTask.build_model uses it to resolve each entry of --losses
# by its "name" key.
loss_choices = ClassChoices(
    name="loss",
    classes={
        "discriminator_loss": UASRDiscriminatorLoss,
        "gradient_penalty": UASRGradientPenalty,
        "smoothness_penalty": UASRSmoothnessPenalty,
        "phoneme_diversity_loss": UASRPhonemeDiversityLoss,
        "pseudo_label_loss": UASRPseudoLabelLoss,
    },
    type_check=AbsUASRLoss,
    default="discriminator_loss",
)


class UASRTask(AbsTask):
    """Task definition that wires together the unsupervised-ASR (UASR) model.

    The task registers its command-line options, builds the data collate and
    preprocess functions, assembles an ``ESPnetUASRModel`` from a frontend,
    an optional segmenter, a generator, a discriminator and a set of losses,
    and creates two optimizers: one for the generator side and one for the
    discriminator.
    """

    # If you need more than one optimizers, change this value
    num_optimizers: int = 2

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --segmenter and --segmenter_conf
        segmenter_choices,
        # --discriminator and --discriminator_conf
        discriminator_choices,
        # --generator and --generator_conf
        generator_choices,
        # --loss and --loss_conf
        loss_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = UASRTrainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register the UASR-specific command-line options on ``parser``.

        Args:
            parser: Argument parser to extend in place.
        """
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # (the list stored as the "required" default is extended in place)
        required = parser.get_default("required")
        required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="phn",
            choices=["phn"],
            help="The text will be tokenized " "in the specified level token",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            help="non_linguistic_symbols file path",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        group.add_argument(
            "--losses",
            action=NestedDictAction,
            default=[
                {
                    "name": "discriminator_loss",
                    "conf": {},
                },
            ],
            help="The criterions binded with the loss wrappers.",
            # Loss format would be like:
            # losses:
            #   - name: loss1
            #     conf:
            #       weight: 1.0
            #       smoothed: false
            #   - name: loss2
            #     conf:
            #       weight: 0.1
            #       smoothed: false
        )

        group = parser.add_argument_group(description="Task related")
        group.add_argument(
            "--kenlm_path",
            type=str,
            help="path of n-gram kenlm for validation",
        )
        parser.add_argument(
            "--int_pad_value",
            type=int,
            default=0,
            help="Integer padding value for real token sequence",
        )
        parser.add_argument(
            "--fairseq_checkpoint",
            type=str,
            help="Fairseq checkpoint to initialize model",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the collate function used to assemble mini-batches.

        Args:
            args: Parsed options; ``args.int_pad_value`` is read.
            train: Unused here.

        Returns:
            A ``CommonCollateFn`` padding floats with ``0.0`` and integers
            with ``args.int_pad_value``.
        """
        assert check_argument_types()
        # The integer pad value is configurable through --int_pad_value
        # (default 0) instead of being hard-coded.
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=args.int_pad_value)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Return the per-sample preprocessing function, if enabled.

        Args:
            args: Parsed options; ``use_preprocessor``, ``token_type``,
                ``token_list``, ``bpemodel``, ``non_linguistic_symbols`` and
                ``cleaner`` are read.
            train: Forwarded to ``CommonPreprocessor``.

        Returns:
            A ``CommonPreprocessor`` when ``args.use_preprocessor`` is true,
            otherwise ``None``.
        """
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys every sample must provide.

        Args:
            train: Unused here.
            inference: When true, only the speech input is required.

        Returns:
            ``("speech", "text")`` unless ``inference`` is true, in which
            case ``("speech",)``.
        """
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys a sample may additionally provide.

        Args:
            train: Unused here.
            inference: Unused here.

        Returns:
            ``("pseudo_labels", "input_cluster_id")``.
        """
        retval = ("pseudo_labels", "input_cluster_id")
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetUASRModel:
        """Assemble an ``ESPnetUASRModel`` from the parsed options.

        Args:
            args: Parsed options. ``args.token_list`` may be a path to a
                token file (one token per line) or an already loaded
                list/tuple; when it is a path it is replaced in place by the
                loaded list.

        Returns:
            The constructed model, initialized either from the fairseq
            checkpoint given by ``args.fairseq_checkpoint`` or with the
            method named by ``args.init`` (if any).

        Raises:
            RuntimeError: If ``args.token_list`` is neither a string nor a
                list/tuple, or if the fairseq checkpoint carries no usable
                ``"cfg"`` entry.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # load from fairseq checkpoint
        load_fairseq_model = False
        cfg = None
        if args.fairseq_checkpoint is not None:
            load_fairseq_model = True
            ckpt = args.fairseq_checkpoint
            logging.info(f"Loading parameters from fairseq: {ckpt}")

            # Load onto CPU so that a checkpoint saved from a GPU can be read
            # on a CPU-only host; load_state_dict() below copies the values
            # into the model's own parameters.
            # NOTE: torch.load unpickles the file, so only load checkpoints
            # that come from a trusted source.
            state_dict = torch.load(ckpt, map_location="cpu")
            if "cfg" in state_dict and state_dict["cfg"] is not None:
                model_cfg = state_dict["cfg"]["model"]
                logging.info(f"Building model from {model_cfg}")
                # NOTE(review): model_cfg is only logged; ``cfg`` stays None
                # for every sub-module below. Confirm whether ``cfg`` was
                # meant to be set from model_cfg.
            else:
                raise RuntimeError(f"Bad 'cfg' in state_dict of {ckpt}")

        # 1. frontend
        if args.write_collected_feats:
            # Extract features in the model
            # Note(jiatong): if we use write_collected_feats=True (we use
            #                pre-extracted feature for training): we still initial
            #                frontend to allow inference with raw speech signal
            #                but the frontend is not used in training
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            if args.input_size is None:
                input_size = frontend.output_size()
            else:
                input_size = args.input_size
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Segmenter
        if args.segmenter is not None:
            segmenter_class = segmenter_choices.get_class(args.segmenter)
            segmenter = segmenter_class(cfg=cfg, **args.segmenter_conf)
        else:
            segmenter = None

        # 3. Discriminator
        discriminator_class = discriminator_choices.get_class(args.discriminator)
        discriminator = discriminator_class(
            cfg=cfg, input_dim=vocab_size, **args.discriminator_conf
        )

        # 4. Generator
        generator_class = generator_choices.get_class(args.generator)
        generator = generator_class(
            cfg=cfg, input_dim=input_size, output_dim=vocab_size, **args.generator_conf
        )

        # 5. Loss definition
        losses = {}
        if getattr(args, "losses", None) is not None:
            # This check is for the compatibility when load models
            # that packed by older version
            for ctr in args.losses:
                loss_name = ctr["name"]
                # A missing or empty "conf" entry means "no extra options".
                loss_conf = ctr.get("conf") or {}
                logging.info("initialize loss: {}".format(loss_name))
                loss_class = loss_choices.get_class(loss_name)
                if loss_name == "gradient_penalty":
                    # The gradient penalty is the only loss that is handed
                    # the discriminator instance.
                    loss = loss_class(discriminator=discriminator, **loss_conf)
                else:
                    loss = loss_class(**loss_conf)
                losses[loss_name] = loss

        # 6. Build model
        logging.info(f"kenlm_path is: {args.kenlm_path}")
        model = ESPnetUASRModel(
            cfg=cfg,
            frontend=frontend,
            segmenter=segmenter,
            discriminator=discriminator,
            generator=generator,
            losses=losses,
            kenlm_path=args.kenlm_path,
            token_list=args.token_list,
            max_epoch=args.max_epoch,
            vocab_size=vocab_size,
            use_collected_training_feats=args.write_collected_feats,
        )

        # FIXME(kamo): Should be done in model?
        # 7. Initialize
        if load_fairseq_model:
            logging.info(f"Initializing model from {ckpt}")
            model.load_state_dict(state_dict["model"], strict=False)
        else:
            if args.init is not None:
                initialize(model, args.init)

        assert check_return_type(model)
        return model

    @classmethod
    def build_optimizers(
        cls,
        args: argparse.Namespace,
        model: ESPnetUASRModel,
    ) -> List[torch.optim.Optimizer]:
        """Create the generator and discriminator optimizers.

        The generator optimizer (``args.optim`` / ``args.optim_conf``) also
        receives the segmenter parameters and, when the pseudo-label loss is
        present with a positive weight, that loss's decoder parameters. The
        discriminator optimizer uses ``args.optim2`` / ``args.optim2_conf``.

        Args:
            args: Parsed options.
            model: Model exposing ``generator`` and ``discriminator``.

        Returns:
            ``[generator_optimizer, discriminator_optimizer]``.

        Raises:
            ValueError: If ``args.optim`` or ``args.optim2`` does not name a
                registered optimizer class.
        """
        # check
        assert hasattr(model, "generator")
        assert hasattr(model, "discriminator")

        generator_param_list = list(model.generator.parameters())
        discriminator_param_list = list(model.discriminator.parameters())

        # Add optional sets of model parameters
        # build_model() may pass segmenter=None, so guard on the segmenter
        # itself as well; the original flag check is kept unchanged.
        segmenter = getattr(model, "segmenter", None)
        if model.use_segmenter is not None and segmenter is not None:
            generator_param_list += list(segmenter.parameters())
        if (
            "pseudo_label_loss" in model.losses
            and model.losses["pseudo_label_loss"].weight > 0
        ):
            generator_param_list += list(
                model.losses["pseudo_label_loss"].decoder.parameters()
            )

        # define generator optimizer
        optim_generator_class = optim_classes.get(args.optim)
        if optim_generator_class is None:
            # Report the option that was actually looked up (args.optim).
            raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
        optim_generator = optim_generator_class(
            generator_param_list,
            **args.optim_conf,
        )

        # define discriminator optimizer
        optim_discriminator_class = optim_classes.get(args.optim2)
        if optim_discriminator_class is None:
            # Report the option that was actually looked up (args.optim2).
            raise ValueError(f"must be one of {list(optim_classes)}: {args.optim2}")
        optim_discriminator = optim_discriminator_class(
            discriminator_param_list,
            **args.optim2_conf,
        )

        return [optim_generator, optim_discriminator]