Source code for espnet2.iterators.multiple_iter_factory

import logging
from typing import Callable, Collection, Iterator

import numpy as np
from typeguard import check_argument_types

from espnet2.iterators.abs_iter_factory import AbsIterFactory


class MultipleIterFactory(AbsIterFactory):
    """Iter-factory that chains the iterators of several sub-factories.

    Each entry of ``build_funcs`` is a zero-argument callable returning an
    ``AbsIterFactory``. The callables are invoked one at a time while the
    generator returned by :meth:`build_iter` is consumed, so a sub-factory is
    only constructed when its turn comes.

    Args:
        build_funcs: Callables that each build one ``AbsIterFactory``.
            A list copy of the collection is stored.
        seed: Offset added to the epoch number to seed the shuffling RNG.
        shuffle: Default shuffle flag, used when ``build_iter`` is called
            without an explicit ``shuffle`` value.
    """

    def __init__(
        self,
        build_funcs: Collection[Callable[[], AbsIterFactory]],
        seed: int = 0,
        shuffle: bool = False,
    ):
        # Must be called directly inside the function whose arguments are
        # being validated.
        assert check_argument_types()
        self.shuffle = shuffle
        self.seed = seed
        self.build_funcs = list(build_funcs)

    def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
        """Yield every item of every sub-factory's iterator, in sequence.

        Args:
            epoch: Epoch number; forwarded to each sub-factory and combined
                with ``self.seed`` to seed the shuffle of the factory order.
            shuffle: Whether to shuffle. ``None`` means "use ``self.shuffle``".
                The resolved value is also forwarded to each sub-factory.

        Yields:
            Whatever the sub-factories' ``build_iter(epoch, shuffle)`` yield.
        """
        shuffle = self.shuffle if shuffle is None else shuffle

        # Work on a copy so the in-place shuffle never reorders
        # self.build_funcs itself.
        ordered = list(self.build_funcs)
        if shuffle:
            rng = np.random.RandomState(epoch + self.seed)
            rng.shuffle(ordered)

        for index, make_factory in enumerate(ordered):
            logging.info(f"Building {index}th iter-factory...")
            factory = make_factory()
            assert isinstance(factory, AbsIterFactory), type(factory)
            yield from factory.build_iter(epoch, shuffle)