Source code for espnet2.uasr.generator.conv_generator

import argparse
import logging
from typing import Dict, Optional

import torch
from typeguard import check_argument_types

from espnet2.uasr.generator.abs_generator import AbsGenerator
from espnet2.utils.types import str2bool


class TransposeLast(torch.nn.Module):
    """Swap the last two dimensions of a tensor.

    Args:
        deconstruct_idx: Optional index. When it is not ``None`` the input
            of ``forward`` is indexed with it first and the selected item
            is the one that gets transposed.
    """

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        """Return ``x`` (or ``x[deconstruct_idx]``) with dims -2 and -1 swapped."""
        idx = self.deconstruct_idx
        selected = x if idx is None else x[idx]
        return selected.transpose(-2, -1)
class SamePad(torch.nn.Module):
    """Trim trailing entries from the last dimension of a 3-D input.

    The number of trimmed entries is fixed at construction time and stored
    in ``self.remove``.

    Args:
        kernel_size: Kernel size the trim amount is derived from.
        causal: If true, ``kernel_size - 1`` entries are removed; otherwise
            one entry is removed for an even ``kernel_size`` and none for
            an odd one.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if not causal:
            # Only an even kernel leaves one surplus entry to drop.
            self.remove = int(kernel_size % 2 == 0)
        else:
            self.remove = kernel_size - 1

    def forward(self, x):
        """Return ``x`` without its last ``self.remove`` entries on dim 2."""
        trim = self.remove
        return x[:, :, :-trim] if trim > 0 else x
class ConvGenerator(AbsGenerator):
    """Convolutional generator for UASR.

    Optionally batch-normalises the non-padded input frames and adds a
    linear residual branch, then projects the features to ``output_dim``
    channels with a single 1-D convolution.

    Args:
        input_dim: Channel size of the incoming features.
        output_dim: Number of output channels of the convolution.
        cfg: Optional dict of ``generator_*`` settings. When given, it is
            used instead of the keyword arguments below and batch
            normalisation is switched off.
        conv_kernel: Convolution kernel size.
        conv_dilation: Convolution dilation.
        conv_stride: Convolution stride; also used to subsample the
            padding mask in ``forward``.
        pad: Convolution padding; a negative value selects
            ``conv_kernel // 2``.
        bias: Whether the convolution has a bias term.
        dropout: Dropout probability.
        batch_norm: Whether to batch-normalise the non-padded frames.
        batch_norm_weight: Value the batch-norm weight is filled with.
        residual: Whether to add the linear residual branch.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        cfg: Optional[Dict] = None,
        conv_kernel: int = 3,
        conv_dilation: int = 1,
        conv_stride: int = 9,
        pad: int = -1,
        bias: str2bool = False,
        dropout: float = 0.0,
        batch_norm: str2bool = True,
        batch_norm_weight: float = 30.0,
        residual: str2bool = True,
    ):
        super().__init__()
        assert check_argument_types()

        self.input_dim = input_dim
        self.output_dim = output_dim

        if cfg is None:
            # Hyper-parameters come straight from the keyword arguments.
            self.conv_kernel = conv_kernel
            self.conv_dilation = conv_dilation
            self.conv_stride = conv_stride
            self.pad = pad
            self.bias = bias
            self.dropout = torch.nn.Dropout(dropout)
            self.batch_norm = batch_norm
            self.batch_norm_weight = batch_norm_weight
            self.residual = residual
        else:
            # Hyper-parameters come from the ``generator_*`` entries of cfg.
            opts = argparse.Namespace(**cfg)
            self.conv_kernel = opts.generator_kernel
            self.conv_dilation = opts.generator_dilation
            self.conv_stride = opts.generator_stride
            self.pad = opts.generator_pad
            self.bias = opts.generator_bias
            self.dropout = torch.nn.Dropout(opts.generator_dropout)
            # TODO(Dongji): batch_norm is not in cfg
            self.batch_norm = False
            self.batch_norm_weight = opts.generator_batch_norm
            self.residual = opts.generator_residual

        # A negative ``pad`` means "derive the padding from the kernel size".
        self.padding = self.conv_kernel // 2 if self.pad < 0 else self.pad

        # Conv1d convolves over the last axis, so the last two axes are
        # swapped before and after the convolution.
        to_channels_first = TransposeLast()
        conv = torch.nn.Conv1d(
            input_dim,
            output_dim,
            kernel_size=self.conv_kernel,
            stride=self.conv_stride,
            dilation=self.conv_dilation,
            padding=self.padding,
            bias=self.bias,
        )
        to_channels_last = TransposeLast()
        self.proj = torch.nn.Sequential(to_channels_first, conv, to_channels_last)

        if self.batch_norm:
            self.bn = torch.nn.BatchNorm1d(input_dim)
            self.bn.weight.data.fill_(self.batch_norm_weight)

        if self.residual:
            self.in_proj = torch.nn.Linear(input_dim, input_dim)

    def output_size(self):
        """Return the number of output channels (``output_dim``)."""
        return self.output_dim

    def forward(
        self,
        feats: torch.Tensor,
        text: Optional[torch.Tensor],
        feats_padding_mask: torch.Tensor,
    ):
        """Generate samples from features and one-hot encode the text.

        Args:
            feats: Input features; presumably (batch, time, input_dim) --
                TODO confirm against the caller.
            text: Optional tensor of integer indices that is expanded to
                one-hot vectors of size ``output_dim``. Must contain at
                least one non-zero entry.
            feats_padding_mask: 2-D mask over ``feats``; entries that are
                False are the ones fed to batch normalisation.

        Returns:
            Tuple ``(generated_sample, real_sample, inter_x,
            generated_sample_padding_mask)``. ``real_sample`` is ``None``
            when ``text`` is ``None`` and ``inter_x`` is ``None`` when the
            residual branch is disabled.
        """
        inter_x = None

        if self.batch_norm:
            feats = self.bn_padded_data(feats, feats_padding_mask)

        if self.residual:
            inter_x = self.in_proj(self.dropout(feats))
            feats = feats + inter_x

        generated_sample = self.proj(self.dropout(feats))

        # Subsample the mask with the convolution stride, then reconcile any
        # length mismatch with the convolution output.
        out_mask = feats_padding_mask[:, :: self.conv_stride]
        if out_mask.size(1) != generated_sample.size(1):
            resized = out_mask.new_zeros(generated_sample.shape[:-1])
            diff = resized.size(1) - out_mask.size(1)
            if diff > 0:
                # Output is longer: keep the leading ``diff`` entries zero.
                resized[:, diff:] = out_mask
            else:
                logging.info("ATTENTION: make sure that you are using V2 instead of V1")
                assert diff < 0
                # Output is shorter: drop the trailing mask entries.
                resized = out_mask[:, :diff]
            out_mask = resized

        real_sample = None
        if text is not None:
            assert torch.count_nonzero(text) > 0
            one_hot = generated_sample.new_zeros(text.numel(), self.output_dim)
            one_hot.scatter_(1, text.view(-1, 1).long(), 1)
            real_sample = one_hot.view(text.shape + (self.output_dim,))

        return generated_sample, real_sample, inter_x, out_mask

    def bn_padded_data(self, feature: torch.Tensor, padding_mask: torch.Tensor):
        """Batch-normalise only the entries where ``padding_mask`` is False.

        Args:
            feature: Tensor to normalise; it is cloned, not modified.
            padding_mask: Boolean mask used to index ``feature``; entries
                where it is True are copied through unchanged.

        Returns:
            A copy of ``feature`` with the selected entries replaced by
            their batch-normalised values.
        """
        keep = ~padding_mask
        normed_feature = feature.clone()
        # BatchNorm1d gets a trailing singleton axis, removed again afterwards.
        normed_feature[keep] = self.bn(feature[keep].unsqueeze(-1)).squeeze(-1)
        return normed_feature