import argparse
import logging
from contextlib import contextmanager
from typing import Dict, Optional, Tuple
import editdistance
import torch
import torch.nn.functional as F
from packaging.version import parse as V
from typeguard import check_argument_types
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.uasr.discriminator.abs_discriminator import AbsDiscriminator
from espnet2.uasr.generator.abs_generator import AbsGenerator
from espnet2.uasr.loss.abs_loss import AbsUASRLoss
from espnet2.uasr.segmenter.abs_segmenter import AbsSegmenter
from espnet2.utils.types import str2bool
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
    # Provide a no-op autocast for torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
try:
import kenlm # for CI import
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
kenlm = None
class ESPnetUASRModel(AbsESPnetModel):
"""Unsupervised ASR model.
The source code is from FAIRSEQ:
https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec/unsupervised
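    This implements the wav2vec-U framework: a generator maps speech
    features to phoneme distributions, and a discriminator is trained
    adversarially against unpaired phoneme text.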
"""
def __init__(
self,
frontend: Optional[AbsFrontend],
segmenter: Optional[AbsSegmenter],
generator: AbsGenerator,
discriminator: AbsDiscriminator,
losses: Dict[str, AbsUASRLoss],
kenlm_path: Optional[str],
token_list: Optional[list],
max_epoch: Optional[int],
vocab_size: int,
cfg: Optional[Dict] = None,
pad: int = 1,
sil_token: str = "<SIL>",
sos_token: str = "<s>",
eos_token: str = "</s>",
skip_softmax: str2bool = False,
use_gumbel: str2bool = False,
use_hard_gumbel: str2bool = True,
min_temperature: float = 0.1,
max_temperature: float = 2.0,
decay_temperature: float = 0.99995,
use_collected_training_feats: str2bool = False,
):
assert check_argument_types()
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.frontend = frontend
self.segmenter = segmenter
        self.use_segmenter = segmenter is not None
self.generator = generator
self.discriminator = discriminator
self.pad = pad
if cfg is not None:
cfg = argparse.Namespace(**cfg)
self.skip_softmax = cfg.no_softmax
self.use_gumbel = cfg.gumbel
self.use_hard_gumbel = cfg.hard_gumbel
else:
self.skip_softmax = skip_softmax
self.use_gumbel = use_gumbel
self.use_hard_gumbel = use_hard_gumbel
self.use_collected_training_feats = use_collected_training_feats
self.min_temperature = min_temperature
self.max_temperature = max_temperature
self.decay_temperature = decay_temperature
self.current_temperature = max_temperature
self._number_updates = 0
self._number_epochs = 0
self.max_epoch = max_epoch
# for loss registration
self.losses = torch.nn.ModuleDict(losses)
# for validation
self.vocab_size = vocab_size
self.token_list = token_list
self.token_id_converter = TokenIDConverter(token_list=token_list)
self.sil = self.token_id_converter.tokens2ids([sil_token])[0]
self.sos = self.token_id_converter.tokens2ids([sos_token])[0]
self.eos = self.token_id_converter.tokens2ids([eos_token])[0]
        self.kenlm = None
        if kenlm_path:
            # kenlm is an optional dependency; it is only required when an LM
            # path is actually given
            assert (
                kenlm is not None
            ), "kenlm is not installed, please install from tools/installers"
            self.kenlm = kenlm.Model(kenlm_path)
@property
def number_updates(self):
return self._number_updates
@number_updates.setter
def number_updates(self, iiter: int):
assert check_argument_types() and iiter >= 0
self._number_updates = iiter
    def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: Optional[torch.Tensor] = None,
text_lengths: Optional[torch.Tensor] = None,
pseudo_labels: Optional[torch.Tensor] = None,
pseudo_labels_lengths: Optional[torch.Tensor] = None,
do_validation: Optional[str2bool] = False,
print_hyp: Optional[str2bool] = False,
**kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Optional[torch.Tensor]]:
"""Frontend + Segmenter + Generator + Discriminator + Calc Loss
Args:
"""
stats = {}
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (
speech.shape,
speech_lengths.shape,
text.shape,
text_lengths.shape,
)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Feats encode (Extract feats + Apply segmenter)
feats, padding_mask = self.encode(speech, speech_lengths)
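        # padding_mask is True at padded positions (cf. make_pad_mask)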
# 2. Generate fake samples
(
generated_sample,
real_sample,
x_inter,
generated_sample_padding_mask,
) = self.generator(feats, text, padding_mask)
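        # generated_sample: generator logits over the phoneme vocabulary;
        # real_sample: sample built from the unpaired real text; x_inter is an
        # intermediate generator output consumed by the pseudo-label loss below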
# 3. Reprocess segments
if self.use_segmenter:
(
generated_sample,
generated_sample_padding_mask,
) = self.segmenter.logit_segment(
generated_sample, generated_sample_padding_mask
)
# for phone_diversity_loss
generated_sample_logits = generated_sample
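        # During training, optionally sample with Gumbel-Softmax at the current
        # annealed temperature (hard=True yields straight-through one-hot
        # samples); otherwise apply a plain softmax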
if not self.skip_softmax:
if self.training and self.use_gumbel:
generated_sample = F.gumbel_softmax(
generated_sample_logits.float(),
                tau=self.current_temperature,
hard=self.use_hard_gumbel,
).type_as(generated_sample_logits)
else:
generated_sample = generated_sample_logits.softmax(-1)
# for validation
vocab_seen = None
if do_validation:
batch_num_errors = 0
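            # Greedy decoding: argmax over the generator output, with padded
            # frames forced to the <pad> ID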
batched_hyp_ids = generated_sample.argmax(-1)
batched_hyp_ids[generated_sample_padding_mask] = self.pad
# for kenlm ppl metric
batch_lm_log_prob = 0
batch_num_hyp_tokens = 0
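            # vocab_seen tracks which non-special vocabulary entries occur in
            # the hypotheses (IDs 0-3 are treated as special tokens; cf. the
            # `hyp_ids >= 4` filter below)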
vocab_seen = torch.zeros(self.vocab_size - 4, dtype=torch.bool)
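            # Per utterance: strip special tokens, collapse repeats, drop
            # silence, then score against the reference with edit distance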
for hyp_ids, ref_ids in zip(batched_hyp_ids, text):
# remove <pad> and <unk>
hyp_ids = hyp_ids[hyp_ids >= 4]
# remove duplicate tokens
hyp_ids = hyp_ids.unique_consecutive()
# remove silence
hyp_ids_nosil = hyp_ids[hyp_ids != self.sil]
hyp_ids_nosil_list = hyp_ids_nosil.tolist()
if self.kenlm:
hyp_token_list = self.token_id_converter.ids2tokens(
integers=hyp_ids
)
hyp_tokens = " ".join(hyp_token_list)
lm_log_prob = self.kenlm.score(hyp_tokens)
batch_lm_log_prob += lm_log_prob
batch_num_hyp_tokens += len(hyp_token_list)
hyp_tokens_index = hyp_ids[hyp_ids >= 4]
vocab_seen[hyp_tokens_index - 4] = True
ref_ids = ref_ids[ref_ids != self.pad]
ref_ids_list = ref_ids.tolist()
num_errors = editdistance.eval(hyp_ids_nosil_list, ref_ids_list)
batch_num_errors += num_errors
stats["batch_num_errors"] = batch_num_errors
stats["batch_num_ref_tokens"] = text_lengths.sum().item()
if self.kenlm:
stats["batch_lm_log_prob"] = batch_lm_log_prob
stats["batch_num_hyp_tokens"] = batch_num_hyp_tokens
stats["batch_size"] = batch_size
# print the last sample in the batch
if print_hyp:
hyp_token_list = self.token_id_converter.ids2tokens(
integers=hyp_ids_nosil
)
hyp_tokens = " ".join(hyp_token_list)
ref_token_list = self.token_id_converter.ids2tokens(integers=ref_ids)
ref_tokens = " ".join(ref_token_list)
logging.info(f"[REF]: {ref_tokens}")
logging.info(f"[HYP]: {hyp_tokens}")
real_sample_padding_mask = text == self.pad
        # 4. Discriminator predictions on fake and real samples
generated_sample_prediction = self.discriminator(
generated_sample, generated_sample_padding_mask
)
real_sample_prediction = self.discriminator(
real_sample, real_sample_padding_mask
)
is_discriminative_step = self.is_discriminative_step()
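        # GAN-style alternating optimization: this flag selects whether the
        # discriminator or the generator is being updated on this step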
# 5. Calculate losses
loss_info = []
if "discriminator_loss" in self.losses.keys():
(
generated_sample_prediction_loss,
real_sample_prediction_loss,
) = self.losses["discriminator_loss"](
generated_sample_prediction,
real_sample_prediction,
is_discriminative_step,
)
loss_info.append(
generated_sample_prediction_loss
* self.losses["discriminator_loss"].weight
)
if is_discriminative_step:
loss_info.append(
real_sample_prediction_loss
* self.losses["discriminator_loss"].weight
)
else:
generated_sample_prediction_loss, real_sample_prediction_loss = None, None
if "gradient_penalty" in self.losses.keys():
gp = self.losses["gradient_penalty"](
generated_sample,
real_sample,
self.training,
is_discriminative_step,
)
            loss_info.append(gp * self.losses["gradient_penalty"].weight)
else:
gp = None
if "phoneme_diversity_loss" in self.losses.keys():
pdl = self.losses["phoneme_diversity_loss"](
generated_sample_logits, batch_size, is_discriminative_step
)
loss_info.append(pdl * self.losses["phoneme_diversity_loss"].weight)
else:
pdl = None
if "smoothness_penalty" in self.losses.keys():
sp = self.losses["smoothness_penalty"](
generated_sample_logits,
generated_sample_padding_mask,
batch_size,
is_discriminative_step,
)
loss_info.append(sp * self.losses["smoothness_penalty"].weight)
else:
sp = None
if "pseudo_label_loss" in self.losses.keys() and pseudo_labels is not None:
mmi = self.losses["pseudo_label_loss"](
x_inter, pseudo_labels, is_discriminative_step
)
loss_info.append(mmi * self.losses["pseudo_label_loss"].weight)
else:
mmi = None
        # Update the Gumbel temperature and the update counter (the counter
        # drives both annealing and generator/discriminator alternation)
        self._change_temperature()
        self.number_updates += 1
loss = sum(loss_info)
# Collect total loss stats
stats["loss"] = loss.detach()
stats["generated_sample_prediction_loss"] = generated_sample_prediction_loss
stats["real_sample_prediction_loss"] = real_sample_prediction_loss
stats["gp"] = gp
stats["sp"] = sp
stats["pdl"] = pdl
stats["mmi"] = mmi
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight, vocab_seen
    def inference(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Feats encode (Extract feats + Apply segmenter)
feats, padding_mask = self.encode(speech, speech_lengths)
# 2. Generate fake samples
(
generated_sample,
_,
x_inter,
generated_sample_padding_mask,
) = self.generator(feats, None, padding_mask)
# generated_sample = generated_sample.softmax(-1)
return generated_sample, generated_sample_padding_mask
    def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: Optional[torch.Tensor] = None,
text_lengths: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
speech = F.layer_norm(speech, speech.shape)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None and not self.use_collected_training_feats:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
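            # Normalize the raw waveform to zero mean / unit variance over the
            # whole padded batch tensor before feature extraction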
speech = F.layer_norm(speech, speech.shape)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract (usually with pre-extracted feat)
# logging.info("use exisitng features")
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
    def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
with autocast(False):
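            # Keep feature extraction in float32 even under mixed-precision
            # training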
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
padding_mask = make_pad_mask(feats_lengths).to(feats.device)
            # 2. Apply segmenter
if self.use_segmenter:
feats, padding_mask = self.segmenter.pre_segment(feats, padding_mask)
return feats, padding_mask
    def is_discriminative_step(self):
return self.number_updates % 2 == 1
    def get_optim_index(self):
return self.number_updates % 2
def _change_temperature(self):
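        # Exponential annealing, clamped from below:
        #   T_t = max(max_temperature * decay_temperature ** t, min_temperature)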
self.current_temperature = max(
self.max_temperature * self.decay_temperature**self.number_updates,
self.min_temperature,
)