Source code for espnet2.gan_tts.vits.monotonic_align.__init__

"""Maximum path calculation module.

This code is based on https://github.com/jaywalnut310/vits.

"""

import warnings

import numpy as np
import torch
from numba import njit, prange

try:
    from .core import maximum_path_c

    is_cython_avalable = True
except ImportError:
    is_cython_avalable = False
    warnings.warn(
        "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
        "If you want to use the cython version, please build it as follows: "
        "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`"
    )


[docs]def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: """Calculate maximum path. Args: neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text). attn_mask (Tensor): Attention mask (B, T_feats, T_text). Returns: Tensor: Maximum path tensor (B, T_feats, T_text). """ device, dtype = neg_x_ent.device, neg_x_ent.dtype neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32) path = np.zeros(neg_x_ent.shape, dtype=np.int32) t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32) t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32) if is_cython_avalable: maximum_path_c(path, neg_x_ent, t_t_max, t_s_max) else: maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max) return torch.from_numpy(path).to(device=device, dtype=dtype)
[docs]@njit def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf): """Calculate a single maximum path with numba.""" index = t_x - 1 for y in range(t_y): for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): if x == y: v_cur = max_neg_val else: v_cur = value[y - 1, x] if x == 0: if y == 0: v_prev = 0.0 else: v_prev = max_neg_val else: v_prev = value[y - 1, x - 1] value[y, x] += max(v_prev, v_cur) for y in range(t_y - 1, -1, -1): path[y, index] = 1 if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): index = index - 1
[docs]@njit(parallel=True) def maximum_path_numba(paths, values, t_ys, t_xs): """Calculate batch maximum path with numba.""" for i in prange(paths.shape[0]): maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])