"""Source code for espnet2.gan_svs.pits.ying_decoder."""

import torch
import torch.nn as nn

import espnet2.gan_svs.pits.modules as modules


# TODO (Yifeng): The docstrings below were originally generated by ChatGPT and may not be accurate.
class YingDecoder(nn.Module):
    """Ying decoder module.

    Crops a ``yin_scope``-long window along dimension 1 of its input
    (optionally shifted per sample) and maps it through a
    1x1 conv -> ``modules.WN`` -> 1x1 conv stack whose input and output
    channel counts are both ``yin_scope``.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        yin_start,
        yin_scope,
        yin_shift_range,
        gin_channels=0,
    ):
        """Initialize the YingDecoder module.

        Args:
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size forwarded to ``modules.WN``.
            dilation_rate (int): Dilation rate forwarded to ``modules.WN``.
            n_layers (int): Number of layers forwarded to ``modules.WN``.
            yin_start (int): Index along dimension 1 at which the unshifted
                crop window starts.
            yin_scope (int): Length of the crop window along dimension 1; also
                the input and output channel count of the decoder.
            yin_shift_range (int): Bound for the random per-sample window
                shift. Because ``torch.randint`` excludes its upper bound,
                shifts are drawn from ``[-yin_shift_range, yin_shift_range - 1]``.
            gin_channels (int, optional): Number of global conditioning
                channels forwarded to ``modules.WN``. Defaults to 0.
        """
        super().__init__()
        self.in_channels = yin_scope
        self.out_channels = yin_scope
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.yin_start = yin_start
        self.yin_scope = yin_scope
        self.yin_shift_range = yin_shift_range

        self.pre = nn.Conv1d(self.in_channels, hidden_channels, 1)
        self.dec = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, self.out_channels, 1)

    def crop_scope(self, x, yin_start, scope_shift):
        """Crop a ``yin_scope``-long window along dimension 1 of each sample.

        Sample ``i`` keeps rows
        ``[yin_start + scope_shift[i], yin_start + scope_shift[i] + yin_scope)``
        of dimension 1; dimension 2 is kept whole.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, T).
            yin_start (int): Start index of the unshifted window.
            scope_shift (torch.Tensor): Integer per-sample shifts of shape (B,).

        Returns:
            torch.Tensor: Cropped tensor of shape (B, yin_scope, T).

        Raises:
            ValueError: If a shifted window does not lie entirely within
                ``[0, C)``. Without this check, Python slice semantics would
                silently wrap a negative start or truncate the window.
        """
        num_rows = x.shape[1]
        crops = []
        for i in range(x.shape[0]):
            start = yin_start + int(scope_shift[i])
            end = start + self.yin_scope
            if start < 0 or end > num_rows:
                raise ValueError(
                    f"Crop window [{start}, {end}) for sample {i} is outside "
                    f"the valid range [0, {num_rows}) of dimension 1."
                )
            crops.append(x[i, start:end, :])
        return torch.stack(crops, dim=0)

    # Draw one random int32 window shift per sample.
    def _sample_scope_shift(self, batch_size):
        # NOTE(review): torch.randint excludes `high`, so +yin_shift_range is
        # never drawn (range is asymmetric). Kept as-is for compatibility.
        return torch.randint(
            -self.yin_shift_range, self.yin_shift_range, (batch_size,), dtype=torch.int
        )

    # Run the masked pre-conv -> WN -> projection stack on a cropped window.
    def _decode(self, z_yin_crop, z_mask, g):
        x = self.pre(z_yin_crop) * z_mask
        x = self.dec(x, z_mask, g=g)
        return self.proj(x) * z_mask

    def infer(self, z_yin, z_mask, g=None, scope_shift=None):
        """Generate a yin prediction for the (shifted) crop window.

        Args:
            z_yin (torch.Tensor): Input tensor of shape (B, C, T).
            z_mask (torch.Tensor): Mask multiplied into the decoder activations;
                presumably of shape (B, 1, T) — confirm against callers.
            g (torch.Tensor, optional): Global conditioning forwarded to
                ``modules.WN``. Defaults to None.
            scope_shift (torch.Tensor, optional): Per-sample window shifts of
                shape (B,). When None (default), shifts are sampled randomly,
                as in the original implementation; pass explicit values (e.g.
                zeros) for deterministic inference.

        Returns:
            torch.Tensor: Predicted window of shape (B, yin_scope, T).
        """
        if scope_shift is None:
            scope_shift = self._sample_scope_shift(z_yin.shape[0])
        z_yin_crop = self.crop_scope(z_yin, self.yin_start, scope_shift)
        return self._decode(z_yin_crop, z_mask, g)

    def forward(self, z_yin, yin_gt, z_mask, g=None):
        """Forward pass of the decoder with a random per-sample window shift.

        Args:
            z_yin (torch.Tensor): Input tensor of shape (B, C, T).
            yin_gt (torch.Tensor): Ground-truth tensor of shape (B, C, T).
            z_mask (torch.Tensor): Mask multiplied into the decoder activations;
                presumably of shape (B, 1, T) — confirm against callers.
            g (torch.Tensor, optional): Global conditioning forwarded to
                ``modules.WN``. Defaults to None.

        Returns:
            tuple: Five values, in this order:
                - yin_gt_crop: ``yin_gt`` cropped with zero shift,
                  shape (B, yin_scope, T).
                - yin_gt_shifted_crop: ``yin_gt`` cropped with ``scope_shift``,
                  shape (B, yin_scope, T).
                - yin_hat_crop: decoder prediction, shape (B, yin_scope, T).
                - z_yin_crop: ``z_yin`` cropped with ``scope_shift``,
                  shape (B, yin_scope, T).
                - scope_shift: sampled int32 shifts of shape (B,).
        """
        scope_shift = self._sample_scope_shift(z_yin.shape[0])

        z_yin_crop = self.crop_scope(z_yin, self.yin_start, scope_shift)
        yin_gt_shifted_crop = self.crop_scope(yin_gt, self.yin_start, scope_shift)
        yin_gt_crop = self.crop_scope(
            yin_gt, self.yin_start, torch.zeros_like(scope_shift)
        )

        yin_hat_crop = self._decode(z_yin_crop, z_mask, g)
        return yin_gt_crop, yin_gt_shifted_crop, yin_hat_crop, z_yin_crop, scope_shift