# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Energy extractor."""
from typing import Any, Dict, Tuple, Union
import humanfriendly
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from espnet2.layers.stft import Stft
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract
from espnet.nets.pytorch_backend.nets_utils import pad_list
class Energy(AbsFeatsExtract):
    """Energy extractor.

    Energy is computed per STFT frame as the square root of the power summed
    over frequency bins (i.e. the L2 norm of the complex spectrum). The
    frame-level sequence can optionally be padded/truncated to a target number
    of frames and/or averaged per token using durations.
    """

    def __init__(
        self,
        fs: Union[int, str] = 22050,
        n_fft: int = 1024,
        win_length: int = None,
        hop_length: int = 256,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        use_token_averaged_energy: bool = True,
        reduction_factor: int = None,
    ):
        """Initialize the energy extractor.

        Args:
            fs: Sampling rate. A string is parsed with
                ``humanfriendly.parse_size``.
            n_fft: FFT size passed to ``Stft``.
            win_length: Window length passed to ``Stft``.
            hop_length: Hop length passed to ``Stft``.
            window: Window type passed to ``Stft``.
            center: Forwarded to ``Stft``.
            normalized: Forwarded to ``Stft``.
            onesided: Forwarded to ``Stft``.
            use_token_averaged_energy: If True, ``forward`` averages the
                frame-level energy over each token's duration.
            reduction_factor: Factor by which token durations are multiplied
                before averaging. Must be an integer >= 1 when
                ``use_token_averaged_energy`` is True.
        """
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)
        self.fs = fs
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.use_token_averaged_energy = use_token_averaged_energy
        if use_token_averaged_energy:
            # forward() multiplies durations by this factor, so it must be set;
            # check for None explicitly instead of failing on `None >= 1`.
            assert reduction_factor is not None and reduction_factor >= 1, (
                "reduction_factor must be an integer >= 1 when "
                "use_token_averaged_energy=True"
            )
        self.reduction_factor = reduction_factor
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def output_size(self) -> int:
        """Return the feature dimension (one energy value per frame/token)."""
        return 1

    def get_parameters(self) -> Dict[str, Any]:
        """Return the configuration of this extractor as a dict.

        NOTE(review): ``onesided`` is accepted by ``__init__`` but is not
        included here — confirm whether that omission is intentional.
        """
        return dict(
            fs=self.fs,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            win_length=self.win_length,
            center=self.stft.center,
            normalized=self.stft.normalized,
            use_token_averaged_energy=self.use_token_averaged_energy,
            reduction_factor=self.reduction_factor,
        )

    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor = None,
        feats_lengths: torch.Tensor = None,
        durations: torch.Tensor = None,
        durations_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate frame-level or token-averaged energy.

        Args:
            input: Batched waveform; presumably (B, T_wav) — TODO confirm the
                shapes ``Stft`` accepts.
            input_lengths: Valid sample count per utterance, shape (B,). If
                None, every utterance is assumed to span ``input.shape[1]``.
            feats_lengths: Optional target frame count per utterance (B,).
                Each energy sequence is zero-padded or truncated to it.
            durations: Per-token durations, one 1-D row per utterance.
                Required when ``use_token_averaged_energy`` is True.
            durations_lengths: Number of tokens per utterance (B,); returned
                as the lengths when token averaging is enabled.

        Returns:
            Tensor: Zero-padded energy with shape (B, T, 1), where T counts
                frames, or tokens when token averaging is enabled.
            Tensor: Length of each returned sequence.
        """
        # If not provided, assume that the inputs all have the same length.
        if input_lengths is None:
            input_lengths = (
                input.new_ones(input.shape[0], dtype=torch.long) * input.shape[1]
            )

        # Domain conversion, e.g. Stft: time -> time-freq.
        input_stft, energy_lengths = self.stft(input, input_lengths)
        assert input_stft.dim() >= 4, input_stft.shape
        assert input_stft.shape[-1] == 2, input_stft.shape

        # input_stft: (..., F, 2) real/imag -> power (..., F)
        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
        # Sum over frequency (B, N, F) -> (B, N); clamp keeps sqrt away from 0.
        energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10))

        # (Optional): Adjust length to match the target features
        # (e.g. a mel-spectrogram).
        if feats_lengths is not None:
            energy = [
                self._adjust_num_frames(e[:el].view(-1), fl)
                for e, el, fl in zip(energy, energy_lengths, feats_lengths)
            ]
            energy_lengths = feats_lengths

        # (Optional): Average by duration to calculate token-wise energy.
        if self.use_token_averaged_energy:
            durations = durations * self.reduction_factor
            energy = [
                self._average_by_duration(e[:el].view(-1), d)
                for e, el, d in zip(energy, energy_lengths, durations)
            ]
            energy_lengths = durations_lengths

        # Per-utterance lists (from either optional step) need padding.
        if isinstance(energy, list):
            energy = pad_list(energy, 0.0)

        # Return with the shape (B, T, 1).
        return energy.unsqueeze(-1), energy_lengths

    def _average_by_duration(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Average frame-level values over consecutive token spans.

        Args:
            x: 1-D frame-level values of one utterance.
            d: 1-D token durations in frames (already multiplied by
                ``reduction_factor`` in ``forward``).

        Returns:
            Tensor: One average per token. Zero-duration tokens get 0.0;
                zero tokens yield an empty tensor.
        """
        # Frames may exceed the summed durations by less than one reduction
        # step; those trailing frames are simply not covered by any token.
        assert 0 <= len(x) - d.sum() < self.reduction_factor
        # Boundaries [0, d0, d0+d1, ...]: token i covers x[b[i]:b[i+1]].
        bounds = F.pad(d.cumsum(dim=0), (1, 0))
        averages = []
        for start, end in zip(bounds[:-1], bounds[1:]):
            segment = x[start:end]
            averages.append(segment.mean() if len(segment) != 0 else x.new_tensor(0.0))
        if not averages:
            # torch.stack raises on an empty list.
            return x.new_zeros(0)
        return torch.stack(averages)

    @staticmethod
    def _adjust_num_frames(x: torch.Tensor, num_frames: torch.Tensor) -> torch.Tensor:
        """Zero-pad or truncate a 1-D sequence to exactly ``num_frames``."""
        if num_frames > len(x):
            x = F.pad(x, (0, num_frames - len(x)))
        elif num_frames < len(x):
            x = x[:num_frames]
        return x