Source code for espnet2.gan_tts.melgan.pqmf

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Pseudo QMF modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

import numpy as np
import torch
import torch.nn.functional as F

try:
    # Window functions live in ``scipy.signal.windows``; the old
    # ``scipy.signal.kaiser`` alias was deprecated and later removed.
    from scipy.signal.windows import kaiser
except ImportError:  # fallback for SciPy versions without the public namespace
    from scipy.signal import kaiser


def design_prototype_filter(
    taps: int = 62, cutoff_ratio: float = 0.142, beta: float = 9.0
) -> np.ndarray:
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    The filter is a windowed-sinc low-pass: an ideal low-pass impulse response
    with cut-off ``pi * cutoff_ratio``, sampled at ``taps + 1`` points centered
    on the middle tap, multiplied by a Kaiser window of the same length.

    Args:
        taps (int): The number of filter taps. Must be a non-negative even number.
        cutoff_ratio (float): Cut-off frequency ratio. Must be in (0.0, 1.0).
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If ``taps`` is odd or negative, or ``cutoff_ratio`` is
            outside the open interval (0.0, 1.0).

    .. _`A Kaiser window approach for the design of prototype filters of cosine
        modulated filterbanks`: https://ieeexplore.ieee.org/abstract/document/681427

    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps mush be even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
    # A negative even value would otherwise slip through and fail later with an
    # unrelated IndexError on an empty array.
    assert taps >= 0, "The number of taps must be non-negative."

    # sample positions, centered so that the middle tap sits at 0
    positions = np.arange(taps + 1) - 0.5 * taps

    # make initial filter: sin(omega_c * n) / (pi * n).
    # The middle tap is the 0/0 limit of that expression, which equals
    # cutoff_ratio; pre-fill with it and only divide where n != 0 so that no
    # NaN is ever produced.
    omega_c = np.pi * cutoff_ratio
    h_i = np.full_like(positions, cutoff_ratio, dtype=np.float64)
    np.divide(
        np.sin(omega_c * positions),
        np.pi * positions,
        out=h_i,
        where=positions != 0,
    )

    # apply kaiser window
    w = kaiser(taps + 1, beta)
    return h_i * w
class PQMF(torch.nn.Module):
    """Pseudo-QMF (PQMF) analysis / synthesis filter bank.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(
        self,
        subbands: int = 4,
        taps: int = 62,
        cutoff_ratio: float = 0.142,
        beta: float = 9.0,
    ):
        """Initialize PQMF module.

        The default cutoff_ratio and beta parameters are optimized for
        #subbands = 4. See the discussion in
        https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super().__init__()

        # Cosine-modulate the prototype low-pass filter into one band-pass
        # filter per subband, for both the analysis and the synthesis side.
        prototype = design_prototype_filter(taps, cutoff_ratio, beta)
        num_coefs = len(prototype)
        centered = np.arange(taps + 1) - (taps / 2)

        analysis_bank = np.zeros((subbands, num_coefs))
        synthesis_bank = np.zeros((subbands, num_coefs))
        for band in range(subbands):
            modulation = (2 * band + 1) * (np.pi / (2 * subbands)) * centered
            # The two banks differ only in the sign of this phase offset.
            phase = (-1) ** band * np.pi / 4
            analysis_bank[band] = 2 * prototype * np.cos(modulation + phase)
            synthesis_bank[band] = 2 * prototype * np.cos(modulation - phase)

        # Register the coefficients as float buffers shaped as conv1d weights:
        # analysis is (subbands, 1, taps + 1), synthesis is (1, subbands, taps + 1).
        self.register_buffer(
            "analysis_filter", torch.from_numpy(analysis_bank).float().unsqueeze(1)
        )
        self.register_buffer(
            "synthesis_filter", torch.from_numpy(synthesis_bank).float().unsqueeze(0)
        )

        # Filter for downsampling & upsampling: a per-channel kernel whose only
        # non-zero entry is a 1 at [k, k, 0].
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        for band in range(subbands):
            updown_filter[band, band, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # Zero padding of taps // 2 samples on both sides, applied before the
        # analysis and synthesis filters.
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x: torch.Tensor) -> torch.Tensor:
        """Split a signal into subband signals.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        # Filter with every analysis filter, then downsample with a strided
        # convolution against the updown filter.
        banded = F.conv1d(self.pad_fn(x), self.analysis_filter)
        return F.conv1d(banded, self.updown_filter, stride=self.subbands)

    def synthesis(self, x: torch.Tensor) -> torch.Tensor:
        """Recombine subband signals into a single signal.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        # NOTE(kan-bayashi): Power will be decreased so here multiply by
        #   subbands.
        #   Not sure this is the correct way, it is better to check again.
        # TODO(kan-bayashi): Understand the reconstruction procedure
        upsample_kernel = self.updown_filter * self.subbands
        upsampled = F.conv_transpose1d(x, upsample_kernel, stride=self.subbands)
        return F.conv1d(self.pad_fn(upsampled), self.synthesis_filter)