# Source code for espnet2.gan_svs.vits.modules

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2022 Yifeng Yu
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import torch


class Projection(torch.nn.Module):
    """Pointwise 1-D convolution that emits two equally sized statistics tensors.

    A kernel-size-1 ``Conv1d`` maps ``hidden_channels`` input channels to
    ``2 * out_channels`` output channels; ``forward`` masks that output and
    splits it along the channel axis into two ``out_channels``-wide halves.
    """

    def __init__(self, hidden_channels: int, out_channels: int):
        """Build the projection layer.

        Args:
            hidden_channels: Number of input channels of the convolution.
            out_channels: Number of channels of EACH of the two returned
                tensors (the convolution itself produces twice as many).
        """
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        # Twice the channels so the output can be split into two halves.
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
        """Project ``x``, apply the mask, and split into two tensors.

        Args:
            x: Input tensor of shape (B, attention_dim, T_text).
            x_mask: Mask multiplied element-wise with the projection output.
                NOTE(review): presumably (B, 1, T_text) with zeros at padded
                positions -- confirm against the caller.

        Returns:
            Tuple ``(m_p, logs_p)``: the first and second ``out_channels``
            channels of the masked projection, each of shape
            (B, out_channels, T_text).
            NOTE(review): the names suggest a mean and a log-scale -- confirm.
        """
        # x shape: (B, attention_dim, T_text)
        stats = self.proj(x) * x_mask
        m_p, logs_p = torch.split(stats, self.out_channels, dim=1)
        return m_p, logs_p
def sequence_mask(length, max_length=None):
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        length: 1-D tensor of per-sequence lengths, shape (B,).
        max_length: Width of the mask. When ``None``, the largest value in
            ``length`` is used (or 0 if ``length`` is empty).

    Returns:
        Boolean tensor of shape (B, max_length) on ``length``'s device, where
        entry ``[b, t]`` is True exactly when ``t < length[b]``.
    """
    if max_length is None:
        # ``Tensor.max()`` raises on an empty tensor; an empty batch simply
        # yields an empty (0, 0) mask.
        max_length = length.max() if length.numel() > 0 else 0
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    # Broadcast (1, T) against (B, 1) to get the (B, T) mask.
    return positions.unsqueeze(0) < length.unsqueeze(1)