Task class and data input system for training

Task class

ESPnet1 has too many duplicated Python modules. One of the main goals of ESPnet2 is to provide a common interface so that we can focus on the parts unique to each task.

The Task class is a common framework for building the training tools of each task (ASR, TTS, LM, etc.), inspired by the Task idea in Fairseq. To build your own task, all you have to do is inherit from the AbsTask class:

import argparse

from espnet2.tasks.abs_task import AbsTask
from espnet2.train.abs_espnet_model import AbsESPnetModel

class NewModel(AbsESPnetModel):
    def forward(self, input, target):
        (...)
        # loss: The loss of the task. Must be a scalar value.
        # stats: A dict object, used for logging and validation criterion
        # weight: A scalar value used to normalize the loss and stats values across mini-batches.
        #     In many cases, this value should be equal to the mini-batch size
        return loss, stats, weight

class NewTask(AbsTask):
    @classmethod
    def add_task_arguments(cls, parser):
        parser.add_argument(...)
        (...)

    @classmethod
    def build_collate_fn(cls, args: argparse.Namespace):
        (...)

    @classmethod
    def build_preprocess_fn(cls, args, train):
        (...)

    @classmethod
    def required_data_names(cls, inference: bool = False):
        (...)

    @classmethod
    def optional_data_names(cls, inference: bool = False):
        (...)

    @classmethod
    def build_model(cls, args):
        return NewModel(...)

if __name__ == "__main__":
    # Start training
    NewTask.main()
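
The role of weight can be pictured with a small sketch. It assumes that per-mini-batch values are combined as a weighted average; the helper name weighted_average and the numbers are hypothetical, and the aggregation actually performed inside ESPnet2 may differ.

```python
from typing import List, Tuple


def weighted_average(values_and_weights: List[Tuple[float, float]]) -> float:
    """Average per-mini-batch values, weighting each by its `weight`.

    Hypothetical helper for illustration; not part of ESPnet2.
    """
    total_weight = sum(weight for _, weight in values_and_weights)
    if total_weight == 0:
        raise ValueError("The sum of weights must be non-zero")
    weighted_sum = sum(value * weight for value, weight in values_and_weights)
    return weighted_sum / total_weight


# (loss, weight) pairs as forward() might return them for three
# mini-batches of size 4, 4, and 2 (made-up numbers).
per_batch = [(1.0, 4), (2.0, 4), (4.0, 2)]
epoch_loss = weighted_average(per_batch)
```

With equal weights this reduces to the plain mean; here the smaller last mini-batch contributes less to the result than the other two.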

Data input system

ESPnet2 also provides a command line interface to describe the training corpus. Unlike fairseq or general training frameworks such as pytorch-lightning, our Task class doesn't have an interface for building the dataset explicitly. This is because we target only speech/text tasks, so we haven't needed such a general system so far.

The following is an example of the command line arguments:

python -m espnet2.bin.asr_train \
  --train_data_path_and_name_and_type=/some/path/tr/wav.scp,speech,sound \
  --train_data_path_and_name_and_type=/some/path/tr/token_int,text,text_int \
  --valid_data_path_and_name_and_type=/some/path/dev/wav.scp,speech,sound \
  --valid_data_path_and_name_and_type=/some/path/dev/token_int,text,text_int

First of all, our mini-batch is always a dict object:

# In training iteration
for batch in iterator:
    # e.g. batch = {"speech": ..., "text": ...}
    # Forward
    model(**batch)

Here, model is the same as the model built by Task.build_model().
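
Because the mini-batch is passed as model(**batch), its keys have to match the parameter names of forward(). A minimal sketch of this keyword unpacking follows, using a hypothetical plain function in place of a real model and plain lists in place of tensors:

```python
def forward(speech, text):
    """Hypothetical stand-in for a model's forward(); it only reports its inputs."""
    return f"speech has {len(speech)} frames, text has {len(text)} tokens"


# The keys of the mini-batch must match the parameter names of forward().
batch = {"speech": [0.1, 0.2, 0.3], "text": [10, 2]}
message = forward(**batch)  # same as forward(speech=[...], text=[...])
```

A missing or unexpected key makes the call fail with a TypeError, which is why the key names given on the command line matter.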

You can flexibly construct this mini-batch object using --*_data_path_and_name_and_type. The option can be repeated as many times as you need, and each occurrence corresponds to one element of the mini-batch. Also, keep in mind that there is no distinction between input and target data.

The argument of --train_data_path_and_name_and_type should be given as three values separated by commas, like <file-path>,<key-name>,<file-format>.

  • key-name specifies the key of the dict.

  • file-path is a file/directory path for the data source.

  • file-format indicates the format of the file specified by file-path, e.g. sound, kaldi_ark, etc.
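
The three comma-separated values can be pictured with the following sketch. DataSpec and parse_triplet are hypothetical names introduced only for illustration, and the sketch assumes that the file path itself contains no comma; it is not the parser used by ESPnet2.

```python
from typing import NamedTuple


class DataSpec(NamedTuple):
    file_path: str
    key_name: str
    file_format: str


def parse_triplet(value: str) -> DataSpec:
    """Split "<file-path>,<key-name>,<file-format>" into its three fields.

    Hypothetical helper; assumes the path itself contains no comma.
    """
    fields = value.split(",")
    if len(fields) != 3:
        raise ValueError(
            f"Expected <file-path>,<key-name>,<file-format>, got: {value!r}"
        )
    return DataSpec(*fields)


spec = parse_triplet("/some/path/tr/wav.scp,speech,sound")
```

Here spec.key_name is the key under which the data appears in the mini-batch dict.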

scp file

You can list the supported file formats with the --help option.

python -m espnet2.bin.asr_train --help

Almost all formats are referred to as scp files, following Kaldi-ASR. An scp file is just a text file with two columns per line: the first is the sample id and the second is some value, e.g. a file path, a transcription, or a sequence of numbers.

  • format=npy

    sample_id_a /some/path/a.npy
    sample_id_b /some/path/b.npy
    
  • format=sound

    sample_id_a /some/path/a.flac
    sample_id_b /some/path/b.wav
    
  • format=kaldi_ark

    sample_id_a /some/path/a.ark:1234
    sample_id_b /some/path/a.ark:5678
    
  • format=text_int

    sample_id_a 10 2 4 4
    sample_id_b 3 2 0 1 6 2
    
  • format=text

    sample_id_a hello world
    sample_id_b It is rainy today
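
The two-column layout shared by these formats can be read with a few lines of code. The following is a minimal sketch; read_scp is a hypothetical helper, not an ESPnet2 function, and it keeps every value as a raw string, leaving the interpretation to the chosen format.

```python
import io
from typing import Dict, TextIO


def read_scp(stream: TextIO) -> Dict[str, str]:
    """Read lines of the form "<sample-id> <value>" into a dict.

    Hypothetical helper; the value is kept as a raw string, so it may be a
    file path, a transcription, or a sequence of numbers.
    """
    table = {}
    for line in stream:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split only at the first run of whitespace: the value may contain spaces.
        # A line holding only a sample id raises ValueError here.
        sample_id, value = line.split(maxsplit=1)
        table[sample_id] = value
    return table


# An in-memory file with the text_int example from above.
text_int = read_scp(io.StringIO("sample_id_a 10 2 4 4\nsample_id_b 3 2 0 1 6 2\n"))
token_ids = [int(token) for token in text_int["sample_id_a"].split()]
```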
    

required_data_names() and optional_data_names()

Although an arbitrary dictionary can be created by this system, each task assumes that specific keys are given for specific purposes. For example, the ASR task requires the speech and text keys, whose values are used as input data and target data respectively. See again the following methods of the Task class: required_data_names() and optional_data_names().

class NewTask(AbsTask):
  @classmethod
  def required_data_names(cls, inference: bool = False):
      if not inference:
          retval = ("input", "target")
      else:
          retval = ("input",)
      return retval

  @classmethod
  def optional_data_names(cls, inference: bool = False):
      retval = ("auxiliary_feature",)
      return retval

required_data_names() determines the mandatory data names and optional_data_names() gives the optional ones. Any other name given by command line arguments is not allowed.

# The following is the expected argument
python -m new_task \
  --train_data_path_and_name_and_type=filepath,input,sometype \
  --train_data_path_and_name_and_type=filepath,target,sometype \
  --train_data_path_and_name_and_type=filepath,auxiliary_feature,sometype
# The following raises an error
python -m new_task \
  --train_data_path_and_name_and_type=filepath,unknown,sometype

This system is intended only as an assertion check, so if you feel it is unnecessary, you can turn it off with --allow_variable_data_keys true.

# Ignore assertion checking for data names
python -m new_task \
  --train_data_path_and_name_and_type=filepath,unknown_name,sometype \
  --allow_variable_data_keys true
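
The check described in this section can be sketched as follows. check_data_names is a hypothetical helper and not the check implemented in ESPnet2; in particular, it assumes that a missing required name is always an error and that the flag only relaxes the rejection of unknown names. The optional name "aux" is also made up.

```python
from typing import Iterable, Tuple


def check_data_names(
    given: Iterable[str],
    required: Tuple[str, ...],
    optional: Tuple[str, ...],
    allow_variable_data_keys: bool = False,
) -> None:
    """Raise ValueError if the given data names break the task's expectations.

    Hypothetical helper for illustration; not the check used in ESPnet2.
    """
    given = set(given)
    missing = set(required) - given
    if missing:
        # Assumption: required names are enforced regardless of the flag.
        raise ValueError(f"Required data names are missing: {sorted(missing)}")
    unknown = given - set(required) - set(optional)
    if unknown and not allow_variable_data_keys:
        raise ValueError(f"Unknown data names: {sorted(unknown)}")


required, optional = ("input", "target"), ("aux",)
check_data_names(["input", "target", "aux"], required, optional)  # passes
check_data_names(
    ["input", "target", "unknown_name"],
    required,
    optional,
    allow_variable_data_keys=True,
)  # passes because the unknown-name check is turned off
```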

Customize collate_fn for PyTorch data loader

Task class has a method to customize collate_fn:

class NewTask(AbsTask):
  @classmethod
  def build_collate_fn(cls, args: argparse.Namespace):
    ...

collate_fn is an argument of torch.utils.data.DataLoader, and it can modify the data received from the dataset before the data loader yields it, e.g.:

from torch.utils.data import DataLoader

def collate_fn(data):
    # data is a list of the values returned by the Dataset class
    modified_data = (...)  # transform data here
    return modified_data

# dataset is assumed to be an existing Dataset instance
data_loader = DataLoader(dataset, collate_fn=collate_fn)
for modified_data in data_loader:
    ...

The type of the argument is determined by the input dataset class. Our dataset is always espnet2.train.dataset.ESPnetDataset, whose return value is a tuple of a sample id and a dict of tensors:

batch = ("sample_id", {"speech": tensor, "text": tensor})

Therefore, the type is a list of tuples, each holding a sample id and a dict of tensors.

data = [
  ("sample_id", {"speech": tensor, "text": tensor}),
  ("sample_id2", {"speech": tensor, "text": tensor}),
  ...
]

In espnet2, the return type of collate_fn is supposed to be a tuple of a list of sample ids and a dict of tensors, so the collate_fn for a Task must transform the data into this type.

for ids, batch in data_loader:
  model(**batch)
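
The change of shape from a list of (sample id, dict) tuples to a tuple of (ids, dict) can be sketched as below. simple_collate_fn is a hypothetical helper that gathers values into plain lists; a real collate_fn would stack them into one tensor per key.

```python
from typing import Any, Dict, List, Tuple


def simple_collate_fn(
    data: List[Tuple[str, Dict[str, Any]]]
) -> Tuple[List[str], Dict[str, List[Any]]]:
    """Turn [(id, {key: value}), ...] into ([id, ...], {key: [value, ...]}).

    Hypothetical helper: values are gathered into plain lists here, whereas
    a real collate_fn would stack them into one tensor per key.
    """
    ids = [sample_id for sample_id, _ in data]
    # Assumes every sample has the same keys as the first one.
    keys = data[0][1].keys()
    batch = {key: [sample[key] for _, sample in data] for key in keys}
    return ids, batch


data = [
    ("sample_id", {"speech": [0.1, 0.2], "text": [3]}),
    ("sample_id2", {"speech": [0.3], "text": [4, 5]}),
]
ids, batch = simple_collate_fn(data)
```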

We provide a common collate_fn that supports many cases, so you might not need to customize it. This collate_fn is aware of variable-length sequence features for seq2seq tasks:

  • The first axis of the sequence tensor from the dataset must be the length axis: e.g. (Length, Dim), (Length, Dim, Dim2), or (Length, …)

  • The lengths of the samples don’t need to be equal; they are stacked with zero-padding.

    • The value of padding can be changed.

      from espnet2.train.collate_fn import CommonCollateFn
      @classmethod
      def build_collate_fn(cls, args):
          # float_pad_value is used for float-tensor and int_pad_value is used for int-tensor
          return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
      
  • Tensors that represent the length of each sample are also appended.

    batch = {"speech": ..., "speech_lengths": ..., "text": ..., "text_lengths": ...}
    
  • If the feature is not sequential data, this behavior can be disabled.

    python -m new_task --train_data_path_and_name_and_type=filepath,foo,npy
    
    @classmethod
    def build_collate_fn(cls, args):
        return CommonCollateFn(not_sequence=["foo"])
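
The padding and length bookkeeping listed above can be illustrated with a list-based sketch. pad_collate is a hypothetical stand-in and not CommonCollateFn: it handles only 1-D sequences of numbers, uses a single pad value, and assumes that a key listed in not_sequence simply gets no "<key>_lengths" entry.

```python
from typing import Dict, List, Sequence, Tuple, Union

Number = Union[int, float]
Sample = Tuple[str, Dict[str, List[Number]]]


def pad_collate(
    data: List[Sample],
    pad_value: Number = 0,
    not_sequence: Sequence[str] = (),
) -> Tuple[List[str], Dict[str, list]]:
    """Pad 1-D sequences to a common length and record the original lengths.

    Hypothetical, list-based stand-in for illustration; not CommonCollateFn.
    """
    ids = [sample_id for sample_id, _ in data]
    batch: Dict[str, list] = {}
    for key in data[0][1]:
        values = [sample[key] for _, sample in data]
        max_length = max(len(value) for value in values)
        # Right-pad every sequence up to the longest one in the mini-batch.
        batch[key] = [
            value + [pad_value] * (max_length - len(value)) for value in values
        ]
        if key not in not_sequence:
            batch[key + "_lengths"] = [len(value) for value in values]
    return ids, batch


data = [
    ("sample_id", {"text": [10, 2, 4, 4], "foo": [1.0, 2.0]}),
    ("sample_id2", {"text": [3, 2], "foo": [3.0, 4.0]}),
]
ids, batch = pad_collate(data, pad_value=-1, not_sequence=["foo"])
```

In this sketch, batch contains "text" and "text_lengths" but no "foo_lengths", mirroring the key layout described above.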