import torch
import torch.nn as nn
import espnet2.gan_svs.pits.modules as modules
# TODO(Yifeng): The docstrings in this file were generated by ChatGPT and may be inaccurate; verify them against the code.
class YingDecoder(nn.Module):
    """Ying decoder module.

    Crops a (possibly shifted) window of ``yin_scope`` channels out of the
    latent yin features and decodes it with a 1x1 conv -> ``modules.WN`` ->
    1x1 conv stack into a yin prediction with ``yin_scope`` channels.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        yin_start,
        yin_scope,
        yin_shift_range,
        gin_channels=0,
    ):
        """Initialize the YingDecoder module.

        Args:
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Size of the convolutional kernel used by WN.
            dilation_rate (int): Dilation rate of the WN convolutional layers.
            n_layers (int): Number of WN layers.
            yin_start (int): Start index (along dim 1) of the unshifted crop
                window of the yin features.
            yin_scope (int): Width of the crop window; also the number of
                input and output channels of the decoder.
            yin_shift_range (int): Bound of the random per-sample window
                shift. Shifts are drawn from
                ``[-yin_shift_range, yin_shift_range - 1]`` because the upper
                bound of ``torch.randint`` is exclusive; 0 disables shifting.
            gin_channels (int, optional): Number of global conditioning
                channels passed to WN. Defaults to 0.

        Raises:
            ValueError: If ``yin_shift_range`` is negative.
        """
        super().__init__()
        if yin_shift_range < 0:
            raise ValueError(
                f"yin_shift_range must be non-negative, got {yin_shift_range}."
            )
        self.in_channels = yin_scope
        self.out_channels = yin_scope
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.yin_start = yin_start
        self.yin_scope = yin_scope
        self.yin_shift_range = yin_shift_range

        self.pre = nn.Conv1d(self.in_channels, hidden_channels, 1)
        self.dec = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, self.out_channels, 1)

    def crop_scope(self, x, yin_start, scope_shift):
        """Crop a per-sample window of ``self.yin_scope`` entries along dim 1.

        For batch item ``i`` the slice
        ``x[i, yin_start + scope_shift[i] : yin_start + scope_shift[i] + yin_scope]``
        is kept; all trailing dimensions are kept whole.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T].
            yin_start (int): Start index of the unshifted window.
            scope_shift (torch.Tensor): Integer shift per batch item, shape [B].

        Returns:
            torch.Tensor: Cropped tensor of shape [B, yin_scope, T].

        Raises:
            ValueError: If a shifted window does not fit inside ``[0, C)``.
        """
        batch_size, num_channels = x.shape[0], x.shape[1]
        if batch_size == 0:
            # torch.stack() rejects an empty list, so build the empty result directly.
            return x.new_empty((0, self.yin_scope, *x.shape[2:]))

        crops = []
        for i in range(batch_size):
            start = yin_start + int(scope_shift[i])
            end = start + self.yin_scope
            # Negative starts would wrap around and overlong ends would be
            # silently truncated by slicing; reject both explicitly.
            if start < 0 or end > num_channels:
                raise ValueError(
                    f"Shifted yin window [{start}, {end}) for batch item {i} "
                    f"does not fit in {num_channels} channels."
                )
            crops.append(x[i, start:end])
        return torch.stack(crops, dim=0)

    def _sample_scope_shift(self, batch_size):
        """Draw one random integer (torch.int) window shift per batch item.

        No device is passed, so the tensor is created on torch's default
        device. Values lie in ``[-yin_shift_range, yin_shift_range - 1]``.
        """
        if self.yin_shift_range == 0:
            # torch.randint(0, 0, ...) raises, so a disabled shift is all zeros.
            return torch.zeros(batch_size, dtype=torch.int)
        return torch.randint(
            -self.yin_shift_range,
            self.yin_shift_range,
            (batch_size,),
            dtype=torch.int,
        )

    def _decode(self, z_yin_crop, z_mask, g=None):
        """Apply pre-conv -> WN -> projection, masking before and after."""
        x = self.pre(z_yin_crop) * z_mask
        x = self.dec(x, z_mask, g=g)
        return self.proj(x) * z_mask

    def infer(self, z_yin, z_mask, g=None, scope_shift=None):
        """Generate a yin prediction.

        Args:
            z_yin (torch.Tensor): Latent yin features of shape [B, C, T];
                C must contain the (shifted) crop window.
            z_mask (torch.Tensor): Mask multiplied onto the decoder
                activations; presumably of shape [B, 1, T] (it must broadcast
                against [B, hidden_channels, T]).
            g (torch.Tensor, optional): Global conditioning tensor passed to
                WN. Defaults to None.
            scope_shift (torch.Tensor, optional): Integer window shift per
                batch item, shape [B]. If None (default), a random shift is
                drawn exactly as in training. NOTE(review): this makes
                inference nondeterministic; pass zeros for a fixed window.

        Returns:
            torch.Tensor: Predicted yin tensor of shape [B, yin_scope, T].
        """
        if scope_shift is None:
            scope_shift = self._sample_scope_shift(z_yin.shape[0])
        z_yin_crop = self.crop_scope(z_yin, self.yin_start, scope_shift)
        return self._decode(z_yin_crop, z_mask, g=g)

    def forward(self, z_yin, yin_gt, z_mask, g=None, scope_shift=None):
        """Forward pass of the decoder.

        Args:
            z_yin (torch.Tensor): Latent yin features of shape (B, C, T).
            yin_gt (torch.Tensor): Ground-truth yin of shape (B, C, T).
            z_mask (torch.Tensor): Mask tensor of shape (B, 1, T).
            g (torch.Tensor, optional): Global conditioning tensor passed to
                WN. Defaults to None.
            scope_shift (torch.Tensor, optional): Integer window shift per
                batch item, shape (B,). If None (default), a random shift is
                sampled.

        Returns:
            Tuple of five tensors, in this order:
                - yin_gt_crop: ground truth cropped with zero shift,
                  (B, yin_scope, T).
                - yin_gt_shifted_crop: ground truth cropped with
                  ``scope_shift``, (B, yin_scope, T).
                - yin_hat_crop: decoder prediction, (B, yin_scope, T).
                - z_yin_crop: input cropped with ``scope_shift``,
                  (B, yin_scope, T).
                - scope_shift: the shifts used, (B,).
        """
        if scope_shift is None:
            scope_shift = self._sample_scope_shift(z_yin.shape[0])
        z_yin_crop = self.crop_scope(z_yin, self.yin_start, scope_shift)
        yin_gt_shifted_crop = self.crop_scope(yin_gt, self.yin_start, scope_shift)
        yin_gt_crop = self.crop_scope(
            yin_gt, self.yin_start, torch.zeros_like(scope_shift)
        )
        yin_hat_crop = self._decode(z_yin_crop, z_mask, g=g)
        return yin_gt_crop, yin_gt_shifted_crop, yin_hat_crop, z_yin_crop, scope_shift