import argparse
import logging
from typing import Dict, Optional
import torch
from typeguard import check_argument_types
from espnet2.uasr.generator.abs_generator import AbsGenerator
from espnet2.utils.types import str2bool
class TransposeLast(torch.nn.Module):
    """Transpose the last two dimensions of the input tensor."""

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        # Optionally pick one element out of a tuple/list input first.
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)
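# A minimal usage sketch (not part of the original module): TransposeLast lets
# torch.nn.Conv1d, which expects (batch, channels, time), sit inside a pipeline
# whose tensors are laid out as (batch, time, channels). The function name
# `_demo_transpose_last` is hypothetical, for illustration only.
def _demo_transpose_last():
    x = torch.randn(2, 100, 512)  # (batch, time, feature)
    conv = torch.nn.Sequential(
        TransposeLast(),  # -> (2, 512, 100)
        torch.nn.Conv1d(512, 256, kernel_size=3, padding=1),
        TransposeLast(),  # -> (2, 100, 256)
    )
    assert conv(x).shape == (2, 100, 256)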
class SamePad(torch.nn.Module):
    """Trim trailing frames so "same"-style convolution padding keeps length.

    With symmetric padding of kernel_size // 2, an even kernel size yields one
    extra output frame, which is removed here. For a causal convolution padded
    with kernel_size - 1, all padded frames are removed from the right.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x
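# A minimal sketch (assumed usage, not from this module): with an even kernel
# and padding = kernel_size // 2, Conv1d emits one extra frame, which SamePad
# trims so the output length matches the input. `_demo_same_pad` is a
# hypothetical name.
def _demo_same_pad():
    x = torch.randn(2, 64, 50)  # (batch, channels, time)
    conv = torch.nn.Sequential(
        torch.nn.Conv1d(64, 64, kernel_size=4, padding=2),  # -> time 51
        SamePad(kernel_size=4),  # -> time 50
    )
    assert conv(x).shape == (2, 64, 50)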
class ConvGenerator(AbsGenerator):
    """Convolutional generator for UASR."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        cfg: Optional[Dict] = None,
        conv_kernel: int = 3,
        conv_dilation: int = 1,
        conv_stride: int = 9,
        pad: int = -1,
        bias: str2bool = False,
        dropout: float = 0.0,
        batch_norm: str2bool = True,
        batch_norm_weight: float = 30.0,
        residual: str2bool = True,
    ):
        super().__init__()
        assert check_argument_types()
        self.input_dim = input_dim
        self.output_dim = output_dim

        if cfg is not None:
            # A fairseq-style config dict overrides the keyword arguments.
            cfg = argparse.Namespace(**cfg)
            self.conv_kernel = cfg.generator_kernel
            self.conv_dilation = cfg.generator_dilation
            self.conv_stride = cfg.generator_stride
            self.pad = cfg.generator_pad
            self.bias = cfg.generator_bias
            self.dropout = torch.nn.Dropout(cfg.generator_dropout)
            # TODO(Dongji): batch_norm is not in cfg
            self.batch_norm = False
            self.batch_norm_weight = cfg.generator_batch_norm
            self.residual = cfg.generator_residual
        else:
            self.conv_kernel = conv_kernel
            self.conv_dilation = conv_dilation
            self.conv_stride = conv_stride
            self.pad = pad
            self.bias = bias
            self.dropout = torch.nn.Dropout(dropout)
            self.batch_norm = batch_norm
            self.batch_norm_weight = batch_norm_weight
            self.residual = residual

        # pad < 0 requests "same"-style padding derived from the kernel size.
        if self.pad < 0:
            self.padding = self.conv_kernel // 2
        else:
            self.padding = self.pad

        # Conv1d operates on (batch, channels, time), so transpose in and out.
        self.proj = torch.nn.Sequential(
            TransposeLast(),
            torch.nn.Conv1d(
                input_dim,
                output_dim,
                kernel_size=self.conv_kernel,
                stride=self.conv_stride,
                dilation=self.conv_dilation,
                padding=self.padding,
                bias=self.bias,
            ),
            TransposeLast(),
        )

        if self.batch_norm:
            self.bn = torch.nn.BatchNorm1d(input_dim)
            self.bn.weight.data.fill_(self.batch_norm_weight)
        if self.residual:
            self.in_proj = torch.nn.Linear(input_dim, input_dim)
    def output_size(self):
        return self.output_dim
    def forward(
        self,
        feats: torch.Tensor,
        text: Optional[torch.Tensor],
        feats_padding_mask: torch.Tensor,
    ):
        inter_x = None
        if self.batch_norm:
            feats = self.bn_padded_data(feats, feats_padding_mask)
        if self.residual:
            inter_x = self.in_proj(self.dropout(feats))
            feats = feats + inter_x

        feats = self.dropout(feats)
        generated_sample = self.proj(feats)

        # Downsample the padding mask by the convolution stride, then
        # reconcile any length mismatch with the generator output.
        generated_sample_padding_mask = feats_padding_mask[:, :: self.conv_stride]
        if generated_sample_padding_mask.size(1) != generated_sample.size(1):
            new_padding = generated_sample_padding_mask.new_zeros(
                generated_sample.shape[:-1]
            )
            diff = new_padding.size(1) - generated_sample_padding_mask.size(1)
            if diff > 0:
                new_padding[:, diff:] = generated_sample_padding_mask
            else:
                logging.info("ATTENTION: make sure that you are using V2 instead of V1")
                assert diff < 0
                new_padding = generated_sample_padding_mask[:, :diff]
            generated_sample_padding_mask = new_padding

        # Build one-hot "real" samples from the text for the discriminator.
        real_sample = None
        if text is not None:
            assert torch.count_nonzero(text) > 0
            real_sample = generated_sample.new_zeros(text.numel(), self.output_dim)
            real_sample.scatter_(1, text.view(-1, 1).long(), 1)
            real_sample = real_sample.view(text.shape + (self.output_dim,))

        return generated_sample, real_sample, inter_x, generated_sample_padding_mask
    def bn_padded_data(self, feature: torch.Tensor, padding_mask: torch.Tensor):
        # Apply batch norm only to non-padded frames, leaving padded
        # positions untouched so statistics are not skewed by padding.
        normed_feature = feature.clone()
        normed_feature[~padding_mask] = self.bn(
            feature[~padding_mask].unsqueeze(-1)
        ).squeeze(-1)
        return normed_feature
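# A minimal end-to-end sketch (assumed usage, not from this module): run the
# generator on random features with a padding mask. All shapes below are
# illustrative, and `_demo_conv_generator` is a hypothetical name.
def _demo_conv_generator():
    batch, time, feat_dim, vocab = 2, 27, 512, 40
    generator = ConvGenerator(
        input_dim=feat_dim,
        output_dim=vocab,
        conv_kernel=4,
        conv_stride=3,
        batch_norm=False,
        residual=True,
    )
    feats = torch.randn(batch, time, feat_dim)
    # True marks padded frames; here the last 5 frames of each utterance.
    mask = torch.zeros(batch, time, dtype=torch.bool)
    mask[:, -5:] = True
    text = torch.randint(1, vocab, (batch, 13))
    gen, real, inter, gen_mask = generator(feats, text, mask)
    # gen: (batch, downsampled time, vocab); real: one-hot (batch, 13, vocab)
    print(gen.shape, real.shape, gen_mask.shape)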