import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.fileio.npy_scp import NpyScpWriter
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.forward_adaptor import ForwardAdaptor
from espnet2.train.abs_espnet_model import AbsESPnetModel


@torch.no_grad()
def collect_stats(
    model: Union[AbsESPnetModel, None],
    # "DataLoader and Iterable[...]" evaluates to the Iterable type at runtime;
    # it is kept to document that a DataLoader yielding (utt_ids, batch) pairs
    # is the expected input.
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
"""Perform on collect_stats mode.
Running for deriving the shape information from data
and gathering statistics.
This method is used before executing train().
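
    Example (illustrative sketch; ``train_loader`` and ``valid_loader`` stand
    in for ESPnet iterators yielding ``(utt_id_list, batch_dict)`` pairs):

        >>> collect_stats(
        ...     model=None,  # with model=None, only shape files are written
        ...     train_iter=train_loader,
        ...     valid_iter=valid_loader,
        ...     output_dir=Path("exp/collect_stats"),
        ...     ngpu=0,
        ...     log_interval=None,
        ...     write_collected_feats=False,
        ... )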
"""
assert check_argument_types()
npy_scp_writers = {}
    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
        if log_interval is None:
            try:
                log_interval = max(len(itr) // 20, 10)
            except TypeError:
                # the iterator may not implement __len__
                log_interval = 100

        sum_dict = defaultdict(lambda: 0)
        sq_dict = defaultdict(lambda: 0)
        count_dict = defaultdict(lambda: 0)

        with DatadirWriter(output_dir / mode) as datadir_writer:
            for iiter, (keys, batch) in enumerate(itr, 1):
                batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")

                # 1. Write shape file
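                # Each utterance becomes one "<utt_id> <dim1>,<dim2>,..." entry
                # in the corresponding {name}_shape file, with padding trimmed.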
                for name in batch:
                    if name.endswith("_lengths"):
                        continue
                    for i, (key, data) in enumerate(zip(keys, batch[name])):
                        if f"{name}_lengths" in batch:
                            lg = int(batch[f"{name}_lengths"][i])
                            data = data[:lg]
                        datadir_writer[f"{name}_shape"][key] = ",".join(
                            map(str, data.shape)
                        )

                if model is not None:
                    # 2. Extract feats
                    if ngpu <= 1:
                        data = model.collect_feats(**batch)
                    else:
                        # Note that data_parallel can parallelize only "forward()",
                        # so collect_feats() is exposed via ForwardAdaptor
                        data = data_parallel(
                            ForwardAdaptor(model, "collect_feats"),
                            (),
                            range(ngpu),
                            module_kwargs=batch,
                        )

                    # 3. Calculate sum and square sum
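                    # Running sum and squared sum are sufficient statistics for
                    # global normalization: mean = sum / count and
                    # var = sum_square / count - mean ** 2 can be derived later.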
                    for key, v in data.items():
                        for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
                            # Truncate zero-padding region
                            if f"{key}_lengths" in data:
                                length = data[f"{key}_lengths"][i]
                                # seq: (Length, Dim, ...)
                                seq = seq[:length]
                            else:
                                # seq: (Dim, ...) -> (1, Dim, ...)
                                seq = seq[None]
                            # Accumulate value, its square, and count
                            sum_dict[key] += seq.sum(0)
                            sq_dict[key] += (seq**2).sum(0)
                            count_dict[key] += len(seq)

                            # 4. [Option] Write the derived features as npy files
                            if write_collected_feats:
                                # Instantiate an NpyScpWriter per (key, mode)
                                # on first use
                                if (key, mode) not in npy_scp_writers:
                                    p = output_dir / mode / "collect_feats"
                                    npy_scp_writers[(key, mode)] = NpyScpWriter(
                                        p / f"data_{key}", p / f"{key}.scp"
                                    )
                                # Save the array as an npy file
                                npy_scp_writers[(key, mode)][uttid] = seq

                if iiter % log_interval == 0:
                    logging.info(f"Niter: {iiter}")

        for key in sum_dict:
            np.savez(
                output_dir / mode / f"{key}_stats.npz",
                count=count_dict[key],
                sum=sum_dict[key],
                sum_square=sq_dict[key],
            )

        # batch_keys and stats_keys are used by aggregate_stats_dirs.py.
        # Note: `batch` here still refers to the last batch seen in the loop above.
        with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
            f.write(
                "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
            )
        with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
            f.write("\n".join(sum_dict) + "\n")