Source code for espnet.utils.training.batchfy

import itertools
import logging

import numpy as np


def batchfy_by_seq(
    sorted_data,
    batch_size,
    max_length_in,
    max_length_out,
    min_batch_size=1,
    shortest_first=False,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
):
    """Split length-sorted utterances into batches of adaptive sequence count.

    The number of utterances per batch shrinks as the first utterance of the
    batch gets longer relative to ``max_length_in`` / ``max_length_out``.

    :param List[Tuple[str, Dict[str, Any]]] sorted_data: (utt_id, info) pairs,
        already sorted by length
    :param int batch_size: nominal number of utterances per batch
    :param int max_length_in: input length used to scale the batch size down
    :param int max_length_out: output length used to scale the batch size down
    :param int min_batch_size: lower bound on utterances per batch
        (for multi-gpu)
    :param bool shortest_first: if true, the utterances inside each batch are
        reversed
    :param str ikey: key of the input entry in each info dict
        (for ASR ikey="input", for TTS, MT ikey="output".)
    :param int iaxis: index into ``info[ikey]`` to read the input shape from
        (for ASR, TTS iaxis=0, for MT iaxis=1.)
    :param str okey: key of the output entry in each info dict
        (for ASR, MT okey="output". for TTS okey="input".)
    :param int oaxis: index into ``info[okey]`` to read the output shape from;
        a negative value takes the maximum length over all entries
    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: if ``batch_size`` is not positive or there are fewer
        utterances than ``min_batch_size``
    """
    if batch_size <= 0:
        raise ValueError(f"Invalid batch_size={batch_size}")

    total = len(sorted_data)
    # the data set must be able to fill at least one minimum-size batch
    if total < min_batch_size:
        raise ValueError(
            f"#utts({total}) is less than min_batch_size({min_batch_size})."
        )

    minibatches = []
    head = 0
    while True:
        # the first utterance of the batch decides how many sequences fit
        first_info = sorted_data[head][1]
        in_len = int(first_info[ikey][iaxis]["shape"][0])
        if oaxis >= 0:
            out_len = int(first_info[okey][oaxis]["shape"][0])
        else:
            out_len = max(int(entry["shape"][0]) for entry in first_info[okey])

        # e.g. in_len = 1000 with max_length_in = 800 gives shrink = 1, i.e.
        # half of batch_size; min_batch_size keeps the result from reaching 0
        shrink = max(int(in_len / max_length_in), int(out_len / max_length_out))
        num_seqs = max(min_batch_size, int(batch_size / (1 + shrink)))

        tail = min(total, head + num_seqs)
        current = sorted_data[head:tail]
        if shortest_first:
            current.reverse()

        # a short (final) batch is topped up with randomly chosen utterances
        # taken from the part of the data that precedes it
        if len(current) < min_batch_size:
            shortage = min_batch_size - len(current) % min_batch_size
            fillers = [sorted_data[idx] for idx in np.random.randint(0, head, shortage)]
            if shortest_first:
                fillers.reverse()
            current.extend(fillers)

        minibatches.append(current)

        if tail == total:
            break
        head = tail

    # batch: List[List[Tuple[str, dict]]]
    return minibatches
def batchfy_by_bin(
    sorted_data,
    batch_bins,
    num_batches=0,
    min_batch_size=1,
    shortest_first=False,
    ikey="input",
    okey="output",
):
    """Make variably sized batch set, which maximizes the bins up to `batch_bins`.

    The size of a batch is counted as
    ``(longest output bins so far + current input bins) * #utts``, where the
    bins of an utterance are its number of frames times the feature dimension
    read from the first utterance.

    :param List[Tuple[str, Dict[str, Any]]] sorted_data: (utt_id, info) pairs,
        already sorted by length
    :param int batch_bins: Maximum bins (frames x dim) of a batch
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: if true, the utterances inside each batch are
        reversed
    :param str ikey: key to access input
        (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output
        (for ASR okey="output". for TTS okey="input".)
    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: if ``batch_bins`` is not positive, or a single
        utterance alone is larger than ``batch_bins``
    """
    if batch_bins <= 0:
        raise ValueError(f"invalid batch_bins={batch_bins}")
    length = len(sorted_data)
    # feature dimensions are taken from the first utterance only
    idim = int(sorted_data[0][1][ikey][0]["shape"][1])
    odim = int(sorted_data[0][1][okey][0]["shape"][1])
    logging.info("# utts: " + str(len(sorted_data)))
    minibatches = []
    start = 0
    while True:
        # Dynamic batch size depending on size of samples
        b = 0
        next_size = 0
        max_olen = 0
        while next_size < batch_bins and (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) * idim
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) * odim
            if olen > max_olen:
                max_olen = olen
            next_size = (max_olen + ilen) * (b + 1)
            if next_size <= batch_bins:
                b += 1
            elif b == 0:
                # Even the first utterance of this batch exceeds the budget.
                # (The former `next_size == 0` test could never be true here
                # because this branch implies next_size > batch_bins > 0.)
                raise ValueError(
                    f"Can't fit one sample in batch_bins ({batch_bins}): "
                    f"Please increase the value"
                )
        end = min(length, start + max(min_batch_size, b))
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed:
        # walk backwards, borrowing utterances from the preceding batch.
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                # No earlier batch to borrow from: merge into the next one.
                # NOTE(review): with a single batch, minibatches[i + 1] is the
                # same list as minibatches[i], so that batch ends up dropped.
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        if end == length:
            break
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logging.info(
        str(len(minibatches))
        + " batches containing from "
        + str(min(lengths))
        + " to "
        + str(max(lengths))
        + " samples "
        + "(avg "
        + str(int(np.mean(lengths)))
        + " samples)."
    )
    return minibatches
def batchfy_by_frame(
    sorted_data,
    max_frames_in,
    max_frames_out,
    max_frames_inout,
    num_batches=0,
    min_batch_size=1,
    shortest_first=False,
    ikey="input",
    okey="output",
):
    """Make variable batch set, which maximizes the number of frames per batch.

    A batch is counted as ``longest length in the batch * #utts`` separately
    for the input, the output and their sum; a limit equal to 0 is ignored.

    :param List[Tuple[str, Dict[str, Any]]] sorted_data: (utt_id, info) pairs,
        already sorted by length
    :param int max_frames_in: Maximum input frames of a batch
    :param int max_frames_out: Maximum output frames of a batch
    :param int max_frames_inout: Maximum input+output frames of a batch
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: if true, the utterances inside each batch are
        reversed
    :param str ikey: key to access input
        (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output
        (for ASR okey="output". for TTS okey="input".)
    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: if no limit is positive, or a single utterance alone
        exceeds one of the active limits
    """
    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
        raise ValueError(
            "At least, one of `--batch-frames-in`, `--batch-frames-out` or "
            "`--batch-frames-inout` should be > 0"
        )
    length = len(sorted_data)
    minibatches = []
    start = 0
    end = 0
    while end != length:
        # Dynamic batch size depending on size of samples
        b = 0
        max_olen = 0
        max_ilen = 0
        while (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0])
            if ilen > max_frames_in and max_frames_in != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-in ({max_frames_in}): "
                    f"Please increase the value"
                )
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0])
            if olen > max_frames_out and max_frames_out != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-out ({max_frames_out}): "
                    f"Please increase the value"
                )
            if ilen + olen > max_frames_inout and max_frames_inout != 0:
                # this limit belongs to --batch-frames-inout, not -out
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-inout "
                    f"({max_frames_inout}): Please increase the value"
                )
            max_olen = max(max_olen, olen)
            max_ilen = max(max_ilen, ilen)
            # would the batch still respect every active limit with this utt?
            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
            inout_ok = (max_ilen + max_olen) * (
                b + 1
            ) <= max_frames_inout or max_frames_inout == 0
            if in_ok and out_ok and inout_ok:
                # add more seq in the minibatch
                b += 1
            else:
                # no more seq in the minibatch
                break
        end = min(length, start + b)
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed:
        # walk backwards, borrowing utterances from the preceding batch.
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                # No earlier batch to borrow from: merge into the next one.
                # NOTE(review): with a single batch, minibatches[i + 1] is the
                # same list as minibatches[i], so that batch ends up dropped.
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logging.info(
        str(len(minibatches))
        + " batches containing from "
        + str(min(lengths))
        + " to "
        + str(max(lengths))
        + " samples"
        + "(avg "
        + str(int(np.mean(lengths)))
        + " samples)."
    )
    return minibatches
def batchfy_shuffle(data, batch_size, min_batch_size, num_batches, shortest_first):
    """Make fixed-size batches from randomly shuffled utterances.

    :param Dict[str, Dict[str, Any]] data: utterance id -> info dictionary
    :param int batch_size: number of utterances per batch
    :param int min_batch_size: minimum batch size; a shorter final batch is
        padded with randomly chosen utterances from the preceding ones
    :param int num_batches: # number of batches to use (for debug),
        0 keeps all of them
    :param bool shortest_first: if true, the utterances inside each batch are
        reversed
    :return: List[List[Tuple[str, dict]]] list of batches
    """
    import random

    logging.info("use shuffled batch.")
    # random.sample() needs a real sequence: passing a dict view was
    # deprecated in Python 3.9 and raises TypeError from Python 3.11 on.
    items = list(data.items())
    sorted_data = random.sample(items, len(items))
    logging.info("# utts: " + str(len(sorted_data)))
    # make list of minibatches
    minibatches = []
    start = 0
    while True:
        end = min(len(sorted_data), start + batch_size)
        minibatch = sorted_data[start:end]
        if shortest_first:
            minibatch.reverse()
        # check each batch is more than minimum batchsize
        if len(minibatch) < min_batch_size:
            mod = min_batch_size - len(minibatch) % min_batch_size
            # pad with utterances drawn at random from the earlier batches
            additional_minibatch = [
                sorted_data[i] for i in np.random.randint(0, start, mod)
            ]
            if shortest_first:
                additional_minibatch.reverse()
            minibatch.extend(additional_minibatch)
        minibatches.append(minibatch)
        if end == len(sorted_data):
            break
        start = end

    # for debugging
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    logging.info("# minibatches: " + str(len(minibatches)))
    return minibatches
# Accepted values for the `count` argument of make_batchset.
BATCH_COUNT_CHOICES = [
    "auto",
    "seq",
    "bin",
    "frame",
]

# Accepted values for the `batch_sort_key` argument of make_batchset.
BATCH_SORT_KEY_CHOICES = [
    "input",
    "output",
    "shuffle",
]
def make_batchset(
    data,
    batch_size=0,
    max_length_in=float("inf"),
    max_length_out=float("inf"),
    num_batches=0,
    min_batch_size=1,
    shortest_first=False,
    batch_sort_key="input",
    swap_io=False,
    mt=False,
    count="auto",
    batch_bins=0,
    batch_frames_in=0,
    batch_frames_out=0,
    batch_frames_inout=0,
    iaxis=0,
    oaxis=0,
):
    """Make batch set from json dictionary

    Utterances are first grouped by their "category" value and batches are
    built separately inside each group, e.g.

    >>> data = {'utt1': {'category': 'A', 'input': ...},
    ...         'utt2': {'category': 'B', 'input': ...},
    ...         'utt3': {'category': 'B', 'input': ...},
    ...         'utt4': {'category': 'A', 'input': ...}}
    >>> make_batchset(data, batch_size=2, ...)
    [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]

    Utterances without a "category" all fall into one common group, so the
    result is then the same as calling batchfy_by_{count} directly.

    :param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json
    :param int batch_size: maximum number of sequences in a minibatch.
    :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
    :param int batch_frames_in: maximum number of input frames in a minibatch.
    :param int batch_frames_out: maximum number of output frames in a minibatch.
    :param int batch_frames_inout: maximum number of input+output frames in a
        minibatch.
    :param str count: strategy to count maximum size of batch.
        For choices, see BATCH_COUNT_CHOICES; "auto" picks the first strategy
        whose size argument is non-zero (seq, then bin, then frame)
    :param int max_length_in: maximum length of input to decide adaptive
        batch size
    :param int max_length_out: maximum length of output to decide adaptive
        batch size
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse
    :param str batch_sort_key: how to sort data before creating minibatches
        ["input", "output", "shuffle"]
    :param bool swap_io: if True, use "input" as output and "output"
        as input in `data` dict
    :param bool mt: if True, use 0-axis of "output" as output and 1-axis of
        "output" as input in `data` dict
    :param int iaxis: dimension to access input
        (for ASR, TTS iaxis=0, for MT iaxis=1.)
    :param int oaxis: dimension to access output
        (for ASR, TTS, MT oaxis=0, reserved for future research,
        -1 means all axis.)
    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: on an unknown `count` / `batch_sort_key`, when `count`
        cannot be auto-detected, or when shuffling is combined with a
        strategy other than "seq"
    """
    # --- argument validation -------------------------------------------
    if count not in BATCH_COUNT_CHOICES:
        raise ValueError(
            f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}"
        )
    if batch_sort_key not in BATCH_SORT_KEY_CHOICES:
        raise ValueError(
            f"arg 'batch_sort_key' ({batch_sort_key}) should be "
            f"one of {BATCH_SORT_KEY_CHOICES}"
        )

    # --- resolve which json entries act as input / output ---------------
    # TODO(karita): remove this by creating converter from ASR to TTS json format
    sort_axis = 0
    if swap_io:
        # for TTS: the roles of "input" and "output" are exchanged,
        # and so is the key used for sorting
        ikey, okey = "output", "input"
        flipped = {"input": "output", "output": "input"}
        batch_sort_key = flipped.get(batch_sort_key, batch_sort_key)
    elif mt:
        # for MT
        # NOTE: input is json['output'][1] and output is json['output'][0]
        ikey, okey = "output", "output"
        batch_sort_key = "output"
        sort_axis = 1
        assert iaxis == 1
        assert oaxis == 0
    else:
        ikey, okey = "input", "output"

    # --- resolve the counting strategy ----------------------------------
    if count == "auto":
        uses_frames = (
            batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0
        )
        if batch_size != 0:
            count = "seq"
        elif batch_bins != 0:
            count = "bin"
        elif uses_frames:
            count = "frame"
        else:
            raise ValueError(
                f"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}"
            )
        logging.info(f"count is auto detected as {count}")

    if batch_sort_key == "shuffle" and count != "seq":
        raise ValueError("batch_sort_key=shuffle is only available if batch_count=seq")

    # --- group utterances by category -----------------------------------
    category2data = {}  # Dict[str, dict]
    for utt_id, info in data.items():
        category = info.get("category")
        if category not in category2data:
            category2data[category] = {}
        category2data[category][utt_id] = info

    def _sort_length(item):
        # length (first shape entry) of the entry selected for sorting
        return int(item[1][batch_sort_key][sort_axis]["shape"][0])

    # --- build the batches of every category ----------------------------
    batches_list = []  # List[List[List[Tuple[str, dict]]]]
    for group in category2data.values():
        if batch_sort_key == "shuffle":
            batches_list.append(
                batchfy_shuffle(
                    group, batch_size, min_batch_size, num_batches, shortest_first
                )
            )
            continue

        # sort it by input lengths (long to short unless shortest_first)
        sorted_data = sorted(
            group.items(), key=_sort_length, reverse=not shortest_first
        )
        logging.info("# utts: " + str(len(sorted_data)))

        if count == "seq":
            batches = batchfy_by_seq(
                sorted_data,
                batch_size=batch_size,
                max_length_in=max_length_in,
                max_length_out=max_length_out,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                iaxis=iaxis,
                okey=okey,
                oaxis=oaxis,
            )
        elif count == "bin":
            batches = batchfy_by_bin(
                sorted_data,
                batch_bins=batch_bins,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                okey=okey,
            )
        elif count == "frame":
            batches = batchfy_by_frame(
                sorted_data,
                max_frames_in=batch_frames_in,
                max_frames_out=batch_frames_out,
                max_frames_inout=batch_frames_inout,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                okey=okey,
            )
        batches_list.append(batches)

    # --- merge the per-category results ---------------------------------
    if len(batches_list) == 1:
        batches = batches_list[0]
    else:
        # Concat list. This way is faster than "sum(batch_list, [])"
        batches = list(itertools.chain.from_iterable(batches_list))

    # for debugging
    if num_batches > 0:
        batches = batches[:num_batches]
    logging.info("# minibatches: " + str(len(batches)))

    # batch: List[List[Tuple[str, dict]]]
    return batches