"""Trainer module for speaker recognition."""
import argparse
import dataclasses
import logging
import time
from contextlib import contextmanager
from dataclasses import is_dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import humanfriendly
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim
from packaging.version import parse as V
from typeguard import check_argument_types
from espnet2.iterators.abs_iter_factory import AbsIterFactory
from espnet2.main_funcs.average_nbest_models import average_nbest_models
from espnet2.main_funcs.calculate_all_attentions import calculate_all_attentions
from espnet2.schedulers.abs_scheduler import (
AbsBatchStepScheduler,
AbsEpochStepScheduler,
AbsScheduler,
AbsValEpochStepScheduler,
)
from espnet2.torch_utils.add_gradient_noise import add_gradient_noise
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.recursive_op import recursive_average
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.reporter import Reporter, SubReporter
from espnet2.train.trainer import Trainer, TrainerOptions
from espnet2.utils.build_dataclass import build_dataclass
from espnet2.utils.eer import ComputeErrorRates, ComputeMinDcf, tuneThresholdfromScore
from espnet2.utils.kwargs2args import kwargs2args
if torch.distributed.is_available():
from torch.distributed import ReduceOp
class SpkTrainer(Trainer):
    """Trainer designed for speaker recognition.

    Training is done as closed-set speaker classification; validation is
    an open-set trial evaluation that reports equal error rate (EER) and
    minimum detection cost function (minDCF).
    """

    def __init__(self):
        # Namespace-style class: all functionality is on classmethods.
        raise RuntimeError("This class can't be instantiated.")

    @classmethod
    @torch.no_grad()
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        """Score verification trials for one epoch and report EER/minDCF.

        Each batch is expected to carry ``speech`` and ``speech2`` of shape
        (n_trials, n_segments, ...) plus binary ``spk_labels`` per trial
        (assumed 1 = target/same speaker — TODO confirm against the dataset).
        Embeddings are extracted for both sides of every trial, L2-normalized,
        and scored by negated mean pairwise distance (higher = more similar).

        Args:
            model: Speaker model; called with ``extract_embd=True`` to obtain
                embeddings instead of a loss.
            iterator: Yields ``(utt_id, batch)`` pairs.
            reporter: Receives ``eer`` and ``mindcf`` statistics.
            options: Provides ``ngpu`` and ``no_forward_run``.
            distributed_option: Provides the ``distributed`` flag.
        """
        assert check_argument_types()
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

        model.eval()

        scores = []
        labels = []

        # [For distributed] Because iteration counts are not always equals
        # between processes, send stop-flag to the other processes if the
        # iterator is finished.
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        for utt_id, batch in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break
            batch["utt_id"] = utt_id
            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            # Fold the segment axis into the batch axis so the model sees
            # one utterance per row; restored after embedding extraction.
            org_shape = (batch["speech"].size(0), batch["speech"].size(1))
            batch["speech"] = batch["speech"].flatten(0, 1)
            batch["speech2"] = batch["speech2"].flatten(0, 1)

            speech_embds = model(
                speech=batch["speech"], spk_labels=None, extract_embd=True
            )
            speech2_embds = model(
                speech=batch["speech2"], spk_labels=None, extract_embd=True
            )

            # L2-normalize so cdist reflects angular similarity only.
            speech_embds = F.normalize(speech_embds, p=2, dim=1)
            speech2_embds = F.normalize(speech2_embds, p=2, dim=1)

            speech_embds = speech_embds.view(org_shape[0], org_shape[1], -1)
            speech2_embds = speech2_embds.view(org_shape[0], org_shape[1], -1)

            for i in range(speech_embds.size(0)):
                # Mean distance over all segment pairs, negated so that a
                # larger score means a more likely target trial.
                score = torch.cdist(speech_embds[i], speech2_embds[i])
                score = -1.0 * torch.mean(score)
                scores.append(score.view(1))  # 0-dim to 1-dim tensor for cat

            # One label row per trial in this batch; flattened after cat so
            # labels align one-to-one with the accumulated scores.
            labels.append(batch["spk_labels"])
        else:
            # Loop exhausted without break: tell peers we are done.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        scores = torch.cat(scores).type(torch.float32)
        labels = torch.cat(labels).type(torch.int32).flatten()

        if distributed:
            # Trials are sharded across GPUs; gather every shard so each
            # rank computes the metrics over the full trial list.
            # First exchange shard sizes, since they may differ per rank.
            length = to_device(
                torch.tensor([labels.size(0)], dtype=torch.int32), "cuda"
            )
            lengths_all = [
                to_device(torch.zeros(1, dtype=torch.int32), "cuda")
                for _ in range(torch.distributed.get_world_size())
            ]
            torch.distributed.all_gather(lengths_all, length)

            scores_all = [
                to_device(torch.zeros(i, dtype=torch.float32), "cuda")
                for i in lengths_all
            ]
            torch.distributed.all_gather(scores_all, scores)
            scores = torch.cat(scores_all)

            labels_all = [
                to_device(torch.zeros(i, dtype=torch.int32), "cuda")
                for i in lengths_all
            ]
            torch.distributed.all_gather(labels_all, labels)
            labels = torch.cat(labels_all)

            torch.distributed.barrier()

        scores = scores.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        # exception for collect_stats: a single trial cannot yield a
        # meaningful ROC, so report worst-case placeholder metrics.
        if len(scores) == 1:
            reporter.register(stats=dict(eer=1.0, mindcf=1.0))
            return

        # predictions, ground truth, and the false acceptance rates to
        # calculate EER at.
        results = tuneThresholdfromScore(scores, labels, [1, 0.1])
        eer = results[1]
        fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)

        # p_target, c_miss, and c_falsealarm in NIST minDCF calculation
        p_trg, c_miss, c_fa = 0.05, 1, 1
        mindcf, _ = ComputeMinDcf(fnrs, fprs, thresholds, p_trg, c_miss, c_fa)

        reporter.register(stats=dict(eer=eer, mindcf=mindcf))