Source code for espnet2.asr_transducer.activation

"""Activation functions for Transducer models."""

import torch
from packaging.version import parse as V


def get_activation(
    activation_type: str,
    ftswish_threshold: float = -0.2,
    ftswish_mean_shift: float = 0.0,
    hardtanh_min_val: float = -1.0,
    hardtanh_max_val: float = 1.0,
    leakyrelu_neg_slope: float = 0.01,
    smish_alpha: float = 1.0,
    smish_beta: float = 1.0,
    softplus_beta: float = 1.0,
    softplus_threshold: int = 20,
    swish_beta: float = 1.0,
) -> torch.nn.Module:
    """Return activation function.

    Args:
        activation_type: Activation function type.
        ftswish_threshold: Threshold value for FTSwish activation formulation.
        ftswish_mean_shift: Mean shifting value for FTSwish activation formulation.
        hardtanh_min_val: Minimum value of the linear region range for HardTanh.
        hardtanh_max_val: Maximum value of the linear region range for HardTanh.
        leakyrelu_neg_slope: Negative slope value for LeakyReLU activation formulation.
        smish_alpha: Alpha value for Smish activation formulation.
        smish_beta: Beta value for Smish activation formulation.
        softplus_beta: Beta value for softplus activation formulation in Mish.
        softplus_threshold: Values above this revert to a linear function in Mish.
        swish_beta: Beta value for Swish variant formulation.

    Returns:
        Activation function.

    """
    torch_version = V(torch.__version__)

    activations = {
        "ftswish": (
            FTSwish,
            {"threshold": ftswish_threshold, "mean_shift": ftswish_mean_shift},
        ),
        "hardtanh": (
            torch.nn.Hardtanh,
            {"min_val": hardtanh_min_val, "max_val": hardtanh_max_val},
        ),
        "leaky_relu": (torch.nn.LeakyReLU, {"negative_slope": leakyrelu_neg_slope}),
        "mish": (
            Mish,
            {
                "softplus_beta": softplus_beta,
                "softplus_threshold": softplus_threshold,
                "use_builtin": torch_version >= V("1.9"),
            },
        ),
        "relu": (torch.nn.ReLU, {}),
        "selu": (torch.nn.SELU, {}),
        "smish": (Smish, {"alpha": smish_alpha, "beta": smish_beta}),
        "swish": (
            Swish,
            {"beta": swish_beta, "use_builtin": torch_version >= V("1.8")},
        ),
        "tanh": (torch.nn.Tanh, {}),
        "identity": (torch.nn.Identity, {}),
    }

    act_func, act_args = activations[activation_type]

    return act_func(**act_args)

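# Usage sketch (illustrative, not part of the original module; the helper name,
# activation type, and keyword value below are arbitrary examples): activations
# are built by name through ``get_activation`` and applied element-wise, so the
# output shape matches the input shape.
def _example_get_activation() -> None:
    act = get_activation("swish", swish_beta=1.0)
    x = torch.randn(4, 8)
    assert act(x).shape == x.shape
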
class FTSwish(torch.nn.Module):
    """Flatten-T Swish activation definition.

    FTSwish(x) = x * sigmoid(x) + threshold,
                   where negative outputs are replaced by the threshold value.

    Reference: https://arxiv.org/abs/1812.06247

    Args:
        threshold: Threshold value for FTSwish activation formulation. (threshold < 0)
        mean_shift: Mean shifting value for FTSwish activation formulation.
                      (applied only if != 0, disabled by default)

    """

    def __init__(self, threshold: float = -0.2, mean_shift: float = 0.0) -> None:
        super().__init__()

        assert threshold < 0, "FTSwish threshold parameter should be < 0."

        self.threshold = threshold
        self.mean_shift = mean_shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        x = (x * torch.sigmoid(x)) + self.threshold

        # Clamp negative outputs to the threshold (dtype kept for mixed precision).
        x = torch.where(
            x >= 0, x, torch.tensor([self.threshold], dtype=x.dtype, device=x.device)
        )

        if self.mean_shift != 0:
            x.sub_(self.mean_shift)

        return x

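# Behavior sketch for FTSwish (illustrative helper, not part of the original
# module): for strongly negative inputs x * sigmoid(x) vanishes, the shifted
# value drops below zero, and the output collapses to ``threshold``; for large
# positive inputs the output is approximately x + threshold.
def _example_ftswish_negative_clamp() -> None:
    act = FTSwish(threshold=-0.2)
    out = act(torch.tensor([-20.0, 0.0, 20.0]))
    assert torch.isclose(out[0], torch.tensor(-0.2))
    assert torch.isclose(out[2], torch.tensor(19.8), atol=1e-3)
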
class Mish(torch.nn.Module):
    """Mish activation definition.

    Mish(x) = x * tanh(softplus(x))

    Reference: https://arxiv.org/abs/1908.08681.

    Args:
        softplus_beta: Beta value for softplus activation formulation.
                         (Usually 0 < softplus_beta <= 2)
        softplus_threshold: Values above this revert to a linear function.
                              (Usually 10 <= softplus_threshold <= 20)
        use_builtin: Whether to use PyTorch activation function if available.

    """

    def __init__(
        self,
        softplus_beta: float = 1.0,
        softplus_threshold: int = 20,
        use_builtin: bool = False,
    ) -> None:
        super().__init__()

        if use_builtin:
            self.mish = torch.nn.Mish()
        else:
            self.tanh = torch.nn.Tanh()
            self.softplus = torch.nn.Softplus(
                beta=softplus_beta, threshold=softplus_threshold
            )

            self.mish = lambda x: x * self.tanh(self.softplus(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        return self.mish(x)

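# Equivalence sketch for Mish (illustrative helper, not part of the original
# module): with the default softplus settings the manual path reproduces the
# x * tanh(softplus(x)) formula written out explicitly.
def _example_mish_manual_path() -> None:
    x = torch.randn(3, 5)
    manual = Mish(use_builtin=False)(x)
    reference = x * torch.tanh(torch.nn.functional.softplus(x))
    assert torch.allclose(manual, reference, atol=1e-6)
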
class Smish(torch.nn.Module):
    """Smish activation definition.

    Smish(x) = (alpha * x) * tanh(log(1 + sigmoid(beta * x)))
                 where alpha > 0 and beta > 0

    Reference: https://www.mdpi.com/2079-9292/11/4/540/htm.

    Args:
        alpha: Alpha value for Smish activation formulation.
                 (Usually, alpha = 1. If alpha <= 0, the value is reset to 1.)
        beta: Beta value for Smish activation formulation.
                (Usually, beta = 1. If beta <= 0, the value is reset to 1.)

    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None:
        super().__init__()

        self.tanh = torch.nn.Tanh()

        self.alpha = alpha if alpha > 0 else 1
        self.beta = beta if beta > 0 else 1

        self.smish = lambda x: (self.alpha * x) * self.tanh(
            torch.log(1 + torch.sigmoid((self.beta * x)))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        return self.smish(x)

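# Parameter-guard sketch for Smish (illustrative helper, not part of the
# original module): non-positive alpha/beta are reset to 1 in the constructor,
# so the two modules below compute identical values.
def _example_smish_parameter_guard() -> None:
    x = torch.randn(2, 6)
    default = Smish(alpha=1.0, beta=1.0)(x)
    guarded = Smish(alpha=-3.0, beta=0.0)(x)  # falls back to alpha = beta = 1
    assert torch.allclose(default, guarded)
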
class Swish(torch.nn.Module):
    """Swish activation definition.

    Swish(x) = (beta * x) * sigmoid(x)
                 where beta = 1 defines the standard Swish activation.

    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1
        E-swish variant: https://arxiv.org/abs/1801.07145

    Args:
        beta: Beta parameter for E-Swish.
                (beta >= 1. If beta < 1, the standard Swish is used instead.)
        use_builtin: Whether to use PyTorch function if available.

    """

    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()

        self.beta = beta

        if beta > 1:
            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
        else:
            if use_builtin:
                self.swish = torch.nn.SiLU()
            else:
                self.swish = lambda x: x * torch.sigmoid(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        return self.swish(x)

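# Variant sketch for Swish (illustrative helper, not part of the original
# module): beta > 1 selects the E-swish scaling, which is exactly beta times
# the standard Swish output.
def _example_eswish_scaling() -> None:
    x = torch.randn(4)
    standard = Swish(beta=1.0)(x)
    e_swish = Swish(beta=2.0)(x)
    assert torch.allclose(e_swish, 2.0 * standard, atol=1e-6)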