# Source code for espnet2.enh.layers.complexnn

import torch
import torch.nn as nn
import torch.nn.functional as F





def complex_cat(inputs, axis):
    """Concatenate complex-valued tensors along ``axis``.

    Each tensor in ``inputs`` stores its real and imaginary parts as the
    two halves of dimension ``axis``.  The halves are gathered and
    concatenated separately so the output keeps the
    ``[real..., imag...]`` layout along ``axis``.

    Args:
        inputs: iterable of tensors, each split as [real, imag] on ``axis``.
        axis: dimension holding the real/imag split.

    Returns:
        torch.Tensor: concatenated tensor, all real halves first, then all
        imaginary halves.
    """
    real, imag = [], []
    # plain iteration: the enumerate() index in the original was unused
    for data in inputs:
        r, i = torch.chunk(data, 2, axis)
        real.append(r)
        imag.append(i)
    real = torch.cat(real, axis)
    imag = torch.cat(imag, axis)
    outputs = torch.cat([real, imag], axis)
    return outputs
class ComplexConv2d(nn.Module):
    """Complex 2D convolution over ``[B, C, D, T]`` inputs.

    The channel axis carries real and imaginary parts as the two halves of
    ``complex_axis``; each half is convolved by its own ``nn.Conv2d`` and
    the four partial products are combined with the complex product rule.

    Note: the frequency padding ``padding[0]`` is delegated to
    ``nn.Conv2d``, while the time padding ``padding[1]`` is applied
    manually in ``forward`` (left-only when ``causal``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        dilation=1,
        groups=1,
        causal=True,
        complex_axis=1,
    ):
        """ComplexConv2d.

        in_channels: real+imag
        out_channels: real+imag
        kernel_size : input [B,C,D,T] kernel size in [D,T]
        padding : input [B,C,D,T] padding in [D,T]
        causal: if causal, will padding time dimension's left side,
            otherwise both
        """
        super().__init__()
        # Each branch (real/imag) only sees half of the stacked channels.
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis

        shared_kwargs = dict(
            padding=[self.padding[0], 0],  # time axis is padded in forward()
            dilation=self.dilation,
            groups=self.groups,
        )
        self.real_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size,
            self.stride,
            **shared_kwargs,
        )
        self.imag_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size,
            self.stride,
            **shared_kwargs,
        )

        nn.init.normal_(self.real_conv.weight.data, std=0.05)
        nn.init.normal_(self.imag_conv.weight.data, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.0)
        nn.init.constant_(self.imag_conv.bias, 0.0)

    def forward(self, inputs):
        """Apply the complex convolution and return stacked [real, imag]."""
        # Manual time padding: left-only for causal convolution,
        # symmetric otherwise.
        left = self.padding[1]
        right = 0 if (self.causal and left != 0) else left
        inputs = F.pad(inputs, [left, right, 0, 0])

        if self.complex_axis == 0:
            # Real/imag stacked on dim 0: run both convs on the full input
            # and split each result afterwards.
            real_out = self.real_conv(inputs)
            imag_out = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real_out, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag_out, 2, self.complex_axis)
        else:
            if isinstance(inputs, torch.Tensor):
                real, imag = torch.chunk(inputs, 2, self.complex_axis)
            # NOTE(review): a non-tensor input would leave real/imag unbound
            # here; callers appear to always pass a tensor — confirm.
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        # Complex multiply: (a+bi)(c+di) = (ac-bd) + (ad+bc)i
        real = real2real - imag2imag
        imag = real2imag + imag2real
        return torch.cat([real, imag], self.complex_axis)
class ComplexConvTranspose2d(nn.Module):
    """Complex transposed 2D convolution.

    Real and imaginary parts live in the two halves of ``complex_axis``;
    each half gets its own ``nn.ConvTranspose2d`` and the four partial
    products are combined with the complex product rule.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        output_padding=(0, 0),
        causal=False,
        complex_axis=1,
        groups=1,
    ):
        """ComplexConvTranspose2d.

        in_channels: real+imag
        out_channels: real+imag
        """
        super().__init__()
        # Each branch (real/imag) handles half of the stacked channels.
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups
        # NOTE(review): ``causal`` is accepted but never stored or used here.

        self.real_conv = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size,
            self.stride,
            padding=self.padding,
            output_padding=output_padding,
            groups=self.groups,
        )
        self.imag_conv = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size,
            self.stride,
            padding=self.padding,
            output_padding=output_padding,
            groups=self.groups,
        )
        self.complex_axis = complex_axis

        nn.init.normal_(self.real_conv.weight, std=0.05)
        nn.init.normal_(self.imag_conv.weight, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.0)
        nn.init.constant_(self.imag_conv.bias, 0.0)

    def forward(self, inputs):
        """Apply the complex transposed conv; returns stacked [real, imag]."""
        if isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, self.complex_axis)
        elif isinstance(inputs, (tuple, list)):
            real, imag = inputs[0], inputs[1]

        if self.complex_axis == 0:
            # NOTE(review): this path feeds the full stacked input to each
            # conv rather than the split halves — confirm intended.
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
        else:
            # Re-split tensor input (redundant with the chunk above, kept
            # for behavioral parity with the original implementation).
            if isinstance(inputs, torch.Tensor):
                real, imag = torch.chunk(inputs, 2, self.complex_axis)
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        # Complex multiply: (a+bi)(c+di) = (ac-bd) + (ad+bc)i
        out_real = real2real - imag2imag
        out_imag = real2imag + imag2real
        return torch.cat([out_real, out_imag], self.complex_axis)
class ComplexBatchNorm(torch.nn.Module):
    """Complex-valued batch normalization.

    The input stacks real and imaginary parts as the two halves of
    ``complex_axis``.  Each feature's (real, imag) pair is whitened with
    the inverse square root of its 2x2 covariance matrix, then optionally
    transformed by a learned symmetric 2x2 weight W and complex bias B
    (cf. "Deep Complex Networks", Trabelsi et al., 2018).

    Args:
        num_features: total channels (real + imag); each part uses half.
        eps: Tikhonov regularizer added to the covariance diagonal.
        momentum: running-statistics momentum (``None`` -> cumulative
            moving average).
        affine: if True, learn W (Wrr, Wri, Wii) and B (Br, Bi).
        track_running_stats: if True, keep running mean/covariance used
            for normalization in eval mode.
        complex_axis: dimension along which real/imag are stacked.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        complex_axis=1,
    ):
        super(ComplexBatchNorm, self).__init__()
        # Half of the stacked channels belong to each of real/imag.
        self.num_features = num_features // 2
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.complex_axis = complex_axis
        if self.affine:
            # Symmetric 2x2 weight per feature (Wri is the shared
            # off-diagonal entry) plus a complex bias (Br, Bi).
            self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter("Wrr", None)
            self.register_parameter("Wri", None)
            self.register_parameter("Wii", None)
            self.register_parameter("Br", None)
            self.register_parameter("Bi", None)
        if self.track_running_stats:
            # Running mean (RMr, RMi) and covariance (RVrr, RVri, RVii).
            self.register_buffer("RMr", torch.zeros(self.num_features))
            self.register_buffer("RMi", torch.zeros(self.num_features))
            self.register_buffer("RVrr", torch.ones(self.num_features))
            self.register_buffer("RVri", torch.zeros(self.num_features))
            self.register_buffer("RVii", torch.ones(self.num_features))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("RMr", None)
            self.register_parameter("RMi", None)
            self.register_parameter("RVrr", None)
            self.register_parameter("RVri", None)
            self.register_parameter("RVii", None)
            self.register_parameter("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running statistics to the identity (zero mean, unit cov)."""
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running stats and (if affine) reinitialize W and B."""
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-0.9, +0.9)  # W will be positive-definite
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        # Sanity check: both halves share a shape and carry num_features
        # channels on dim 1.  Not called from forward().
        assert xr.shape == xi.shape
        assert xr.size(1) == self.num_features

    def forward(self, inputs):
        """Whiten (and optionally affine-transform) the complex input.

        Returns a tensor of the same shape, real and imaginary results
        re-stacked along ``complex_axis``.
        """
        # Positional dim argument for consistency with the rest of this
        # file (``axis`` is only a numpy-compat alias in torch).
        xr, xi = torch.chunk(inputs, 2, self.complex_axis)
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # NOTE: The precise meaning of the "training flag" is:
        #   True:  Normalize using batch statistics, update running
        #          statistics if they are being collected.
        #   False: Normalize using running statistics, ignore batch
        #          statistics.
        training = self.training or not self.track_running_stats
        # Reduce over every dim except the feature dim (dim 1).
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        # Broadcast shape [1, C, 1, ...] for per-feature statistics.
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)

        # Mean M computation and centering; updates the running mean when
        # training with tracked stats.
        if training:
            Mr, Mi = xr, xi
            for d in redux:
                Mr = Mr.mean(d, keepdim=True)
                Mi = Mi.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
                self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        # Covariance matrix V computation, with eps as a numerical
        # stabilizer / Tikhonov regularizer on the diagonal; updates the
        # running covariance when training with tracked stats.
        if training:
            Vrr = xr * xr
            Vri = xr * xi
            Vii = xi * xi
            for d in redux:
                Vrr = Vrr.mean(d, keepdim=True)
                Vri = Vri.mean(d, keepdim=True)
                Vii = Vii.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
                self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
                self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        Vrr = Vrr + self.eps
        Vii = Vii + self.eps

        # Matrix inverse square root U = V^-0.5 of the 2x2 matrix:
        # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii
        # det(V).  Replaces the deprecated positional-value call
        # torch.addcmul(Vrr * Vii, -1, Vri, Vri) with plain arithmetic.
        delta = Vrr * Vii - Vri * Vri
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()
        # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (-Vri) * rst

        # Optionally left-multiply U by affine weights W to produce combined
        # weights Z, left-multiply the inputs by Z, then optionally bias:
        #   y = Zx + B
        #   y = WUx + B
        #   y = [Wrr Wri][Urr Uri] [xr] + [Br]
        #       [Wir Wii][Uir Uii] [xi]   [Bi]
        if self.affine:
            Wrr, Wri, Wii = (
                self.Wrr.view(vdim),
                self.Wri.view(vdim),
                self.Wii.view(vdim),
            )
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            # U is symmetric, so Zir == Zri == Uri.
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        outputs = torch.cat([yr, yi], self.complex_axis)
        return outputs

    def extra_repr(self):
        """One-line summary of the constructor arguments."""
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )