import math
from typing import Collection, Dict, List, Tuple, Union
import numpy as np
import torch
from typeguard import check_argument_types, check_return_type
from espnet.nets.pytorch_backend.nets_utils import pad_list


class CommonCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(
self,
float_pad_value: Union[float, int] = 0.0,
int_pad_value: int = -32768,
not_sequence: Collection[str] = (),
):
assert check_argument_types()
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)

    def __repr__(self):
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
return common_collate_fn(
data,
float_pad_value=self.float_pad_value,
int_pad_value=self.int_pad_value,
not_sequence=self.not_sequence,
)
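
# A minimal usage sketch (illustrative, not part of the original module): the
# key name "speech" and the array shapes are assumptions. CommonCollateFn pads
# every sequence key to the longest item in the mini-batch and appends a
# "<key>_lengths" tensor holding the original lengths.
#
#     >>> collate = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
#     >>> batch = [
#     ...     ("utt1", {"speech": np.random.randn(1000).astype(np.float32)}),
#     ...     ("utt2", {"speech": np.random.randn(800).astype(np.float32)}),
#     ... ]
#     >>> uttids, tensors = collate(batch)
#     >>> tensors["speech"].shape     # torch.Size([2, 1000])
#     >>> tensors["speech_lengths"]   # tensor([1000, 800])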


class HuBERTCollateFn(CommonCollateFn):
"""Functor class of common_collate_fn()"""
def __init__(
self,
float_pad_value: Union[float, int] = 0.0,
int_pad_value: int = -32768,
label_downsampling: int = 1,
pad: bool = False,
rand_crop: bool = True,
crop_audio: bool = True,
not_sequence: Collection[str] = (),
):
assert check_argument_types()
super().__init__(
float_pad_value=float_pad_value,
int_pad_value=int_pad_value,
not_sequence=not_sequence,
)
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.label_downsampling = label_downsampling
self.pad = pad
self.rand_crop = rand_crop
self.crop_audio = crop_audio
self.not_sequence = set(not_sequence)

    def __repr__(self):
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value}, "
            f"label_downsampling={self.label_downsampling}, "
            f"pad={self.pad}, rand_crop={self.rand_crop}, "
            f"crop_audio={self.crop_audio})"
        )

    def __call__(
self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
assert "speech" in data[0][1]
assert "text" in data[0][1]
if self.pad:
num_frames = max([sample["speech"].shape[0] for uid, sample in data])
else:
num_frames = min([sample["speech"].shape[0] for uid, sample in data])
new_data = []
for uid, sample in data:
waveform, label = sample["speech"], sample["text"]
assert waveform.ndim == 1
length = waveform.size
            # MFCC features are 10 ms per frame, while HuBERT's transformer
            # output is 20 ms per frame. Downsample the KMeans labels if they
            # were generated from MFCC features.
if self.label_downsampling > 1:
label = label[:: self.label_downsampling]
if self.crop_audio:
waveform, label, length = _crop_audio_label(
waveform, label, length, num_frames, self.rand_crop
)
new_data.append((uid, dict(speech=waveform, text=label)))
return common_collate_fn(
new_data,
float_pad_value=self.float_pad_value,
int_pad_value=self.int_pad_value,
not_sequence=self.not_sequence,
)
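
# Illustrative sketch (key names and shapes assumed): with label_downsampling=2,
# KMeans labels generated from 10 ms MFCC frames are thinned to the 20 ms frame
# rate of HuBERT's transformer output, and with pad=False each waveform is
# cropped to the shortest one in the mini-batch before the common collation.
#
#     >>> fn = HuBERTCollateFn(label_downsampling=2, pad=False, rand_crop=True)
#     >>> batch = [
#     ...     ("utt1", {"speech": np.zeros(32000, dtype=np.float32),
#     ...               "text": np.arange(200)}),   # ~2 s audio, 10 ms labels
#     ...     ("utt2", {"speech": np.zeros(16000, dtype=np.float32),
#     ...               "text": np.arange(100)}),
#     ... ]
#     >>> uttids, tensors = fn(batch)  # both waveforms cropped to 16000 samples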


def _crop_audio_label(
    waveform: np.ndarray,
    label: np.ndarray,
    length: int,
    num_frames: int,
    rand_crop: bool,
) -> Tuple[np.ndarray, np.ndarray, int]:
"""Collate the audio and label at the same time.
Args:
waveform (Tensor): The waveform Tensor with dimensions `(time)`.
label (Tensor): The label Tensor with dimensions `(seq)`.
length (Tensor): The length Tensor with dimension `(1,)`.
num_frames (int): The final length of the waveform.
rand_crop (bool): if ``rand_crop`` is True, the starting index of the
waveform and label is random if the length is longer than the minimum
length in the mini-batch.
Returns:
(Tuple(Tensor, Tensor, Tensor)): Returns the Tensors for the waveform,
label, and the waveform length.
"""
    # Geometry of the label frames: a 25 ms window advancing in 20 ms hops
    # over 16 kHz audio.
    kernel_size = 25
    stride = 20
    sample_rate = 16  # samples per millisecond
    frame_offset = 0
    if waveform.size > num_frames and rand_crop:
        diff = waveform.size - num_frames
        frame_offset = int(torch.randint(diff, size=(1,)))
    elif waveform.size < num_frames:
        num_frames = waveform.size
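    # Translate the crop window from sample indices to label-frame indices,
    # following the 25 ms window / 20 ms hop geometry above.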
label_offset = max(
math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate))
+ 1,
0,
)
num_label = (
math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate))
+ 1
)
waveform = waveform[frame_offset : frame_offset + num_frames]
label = label[label_offset : label_offset + num_label]
length = num_frames
return waveform, label, length
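
# Worked example of the alignment arithmetic above (values assumed): for a
# 16 kHz waveform cropped to num_frames=16000 samples (one second), the 25 ms
# window and 20 ms hop give
#
#     num_label = floor((16000 - 25 * 16) / (20 * 16)) + 1
#               = floor(15600 / 320) + 1
#               = 49
#
# i.e. one label per 20 ms hop whose 25 ms window still fits inside the crop.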


def common_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
float_pad_value: Union[float, int] = 0.0,
int_pad_value: int = -32768,
not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
Examples:
>>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler,
>>> import espnet2.tasks.abs_task
>>> from espnet2.train.dataset import ESPnetDataset
>>> sampler = ConstantBatchSampler(...)
>>> dataset = ESPnetDataset(...)
>>> keys = next(iter(sampler)
>>> batch = [dataset[key] for key in keys]
>>> batch = common_collate_fn(batch)
>>> model(**batch)
Note that the dict-keys of batch are propagated from
that of the dataset as they are.
"""
assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(
not k.endswith("_lengths") for k in data[0]
), f"*_lengths is reserved: {list(data[0])}"
output = {}
for key in data[0]:
        # NOTE(kamo):
        # Each model that finally consumes these values is responsible for
        # remapping the pad_value to the value its own task requires.
if data[0][key].dtype.kind == "i":
pad_value = int_pad_value
else:
pad_value = float_pad_value
array_list = [d[key] for d in data]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list = [torch.from_numpy(a) for a in array_list]
# tensor: (Batch, Length, ...)
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
# lens: (Batch,)
if key not in not_sequence:
lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
output[key + "_lengths"] = lens
output = (uttids, output)
assert check_return_type(output)
return output
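
# Minimal end-to-end sketch for common_collate_fn (synthetic data; the key
# names are illustrative): integer arrays are padded with int_pad_value, float
# arrays with float_pad_value, and a "<key>_lengths" entry is added for every
# key not listed in not_sequence.
#
#     >>> data = [
#     ...     ("utt1", {"speech": np.ones(5, dtype=np.float32),
#     ...               "text": np.array([1, 2, 3])}),
#     ...     ("utt2", {"speech": np.ones(3, dtype=np.float32),
#     ...               "text": np.array([4, 5])}),
#     ... ]
#     >>> uttids, batch = common_collate_fn(data, int_pad_value=-1)
#     >>> batch["text"]           # tensor([[ 1,  2,  3], [ 4,  5, -1]])
#     >>> batch["text_lengths"]   # tensor([3, 2])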