"""Set of methods to build Transducer encoder architecture."""

from typing import Any, Callable, Dict, List, Optional, Union

from espnet2.asr_transducer.activation import get_activation
from espnet2.asr_transducer.encoder.blocks.branchformer import Branchformer
from espnet2.asr_transducer.encoder.blocks.conformer import Conformer
from espnet2.asr_transducer.encoder.blocks.conv1d import Conv1d
from espnet2.asr_transducer.encoder.blocks.conv_input import ConvInput
from espnet2.asr_transducer.encoder.blocks.ebranchformer import EBranchformer
from espnet2.asr_transducer.encoder.modules.attention import (  # noqa: H301
    RelPositionMultiHeadedAttention,
)
from espnet2.asr_transducer.encoder.modules.convolution import (  # noqa: H301
    ConformerConvolution,
    ConvolutionalSpatialGatingUnit,
    DepthwiseConvolution,
)
from espnet2.asr_transducer.encoder.modules.multi_blocks import MultiBlocks
from espnet2.asr_transducer.encoder.modules.positional_encoding import (  # noqa: H301
    RelPositionalEncoding,
)
from espnet2.asr_transducer.normalization import get_normalization
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)


def build_main_parameters(
    pos_wise_act_type: str = "swish",
    conv_mod_act_type: str = "swish",
    pos_enc_dropout_rate: float = 0.0,
    pos_enc_max_len: int = 5000,
    simplified_att_score: bool = False,
    norm_type: str = "layer_norm",
    conv_mod_norm_type: str = "layer_norm",
    after_norm_eps: Optional[float] = None,
    after_norm_partial: Optional[float] = None,
    blockdrop_rate: float = 0.0,
    dynamic_chunk_training: bool = False,
    short_chunk_threshold: float = 0.75,
    short_chunk_size: int = 25,
    num_left_chunks: int = 0,
    **activation_parameters,
) -> Dict[str, Any]:
    """Build encoder main parameters.

    Args:
        pos_wise_act_type: X-former position-wise feed-forward activation type.
        conv_mod_act_type: X-former convolution module activation type.
        pos_enc_dropout_rate: Positional encoding dropout rate.
        pos_enc_max_len: Positional encoding maximum length.
        simplified_att_score: Whether to use simplified attention score computation.
        norm_type: X-former normalization module type.
        conv_mod_norm_type: Conformer convolution module normalization type.
        after_norm_eps: Epsilon value for the final normalization.
        after_norm_partial: Value for the final normalization with RMSNorm.
        blockdrop_rate: Probability of dropping out each encoder block during training.
        dynamic_chunk_training: Whether to use dynamic chunk training.
        short_chunk_threshold: Threshold for dynamic chunk selection.
        short_chunk_size: Minimum number of frames during dynamic chunk training.
        num_left_chunks: Number of left chunks the attention module can see.
            (A zero or negative value means full context.)
        **activation_parameters: Parameters of the activation functions.
            (See espnet2/asr_transducer/activation.py)

    Returns:
        : Main encoder parameters.

    """
    main_params = {}

    main_params["pos_wise_act"] = get_activation(
        pos_wise_act_type, **activation_parameters
    )
    main_params["conv_mod_act"] = get_activation(
        conv_mod_act_type, **activation_parameters
    )

    main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate
    main_params["pos_enc_max_len"] = pos_enc_max_len

    main_params["simplified_att_score"] = simplified_att_score

    main_params["norm_type"] = norm_type
    main_params["conv_mod_norm_type"] = conv_mod_norm_type

    (
        main_params["after_norm_class"],
        main_params["after_norm_args"],
    ) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial)

    main_params["blockdrop_rate"] = blockdrop_rate

    main_params["dynamic_chunk_training"] = dynamic_chunk_training
    main_params["short_chunk_threshold"] = max(0, short_chunk_threshold)
    main_params["short_chunk_size"] = max(0, short_chunk_size)
    main_params["num_left_chunks"] = max(0, num_left_chunks)

    return main_params
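

# A minimal usage sketch (not part of the original module): building main
# parameters for a streaming-style encoder. The keyword values below are
# illustrative assumptions, not recommended settings.
#
#     main_params = build_main_parameters(
#         pos_wise_act_type="swish",
#         norm_type="rms_norm",
#         dynamic_chunk_training=True,
#         num_left_chunks=4,
#     )
#
# The returned dict carries the instantiated activation functions plus the
# normalization class/arguments consumed by the block builders below.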


def build_positional_encoding(
    block_size: int, configuration: Dict[str, Any]
) -> RelPositionalEncoding:
    """Build positional encoding block.

    Args:
        block_size: Input/output size.
        configuration: Positional encoding configuration.

    Returns:
        : Positional encoding module.

    """
    return RelPositionalEncoding(
        block_size,
        configuration.get("pos_enc_dropout_rate", 0.0),
        max_len=configuration.get("pos_enc_max_len", 5000),
    )
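

# Usage sketch (illustrative): the configuration is typically the dict
# returned by build_main_parameters, from which the dropout rate and
# maximum length are read. The block size of 256 is an arbitrary example.
#
#     pos_enc = build_positional_encoding(256, main_params)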


def build_input_block(
    input_size: int,
    configuration: Dict[str, Union[str, int]],
) -> ConvInput:
    """Build encoder input block.

    Args:
        input_size: Input size.
        configuration: Input block configuration.

    Returns:
        : ConvInput block module.

    """
    return ConvInput(
        input_size,
        configuration["conv_size"],
        configuration["subsampling_factor"],
        vgg_like=configuration["vgg_like"],
        output_size=configuration["output_size"],
    )
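

# Illustrative input-block configuration (all values are assumptions chosen
# for the example): an 80-dim feature input subsampled by a factor of 4 into
# a 256-dim block size.
#
#     input_conf = {
#         "conv_size": 64,
#         "subsampling_factor": 4,
#         "vgg_like": True,
#         "output_size": 256,
#     }
#     conv_input = build_input_block(80, input_conf)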


def build_branchformer_block(
    configuration: Dict[str, Any],
    main_params: Dict[str, Any],
) -> Callable[[], Branchformer]:
    """Build Branchformer block.

    Args:
        configuration: Branchformer block configuration.
        main_params: Encoder main parameters.

    Returns:
        : Branchformer block function.

    """
    hidden_size = configuration["hidden_size"]
    linear_size = configuration["linear_size"]
    dropout_rate = configuration.get("dropout_rate", 0.0)

    conv_mod_norm_class, conv_mod_norm_args = get_normalization(
        main_params["conv_mod_norm_type"],
        eps=configuration.get("conv_mod_norm_eps"),
        partial=configuration.get("conv_mod_norm_partial"),
    )

    conv_mod_args = (
        linear_size,
        configuration["conv_mod_kernel_size"],
        conv_mod_norm_class,
        conv_mod_norm_args,
        dropout_rate,
        main_params["dynamic_chunk_training"],
    )

    mult_att_args = (
        configuration.get("heads", 4),
        hidden_size,
        configuration.get("att_dropout_rate", 0.0),
        main_params["simplified_att_score"],
    )

    norm_class, norm_args = get_normalization(
        main_params["norm_type"],
        eps=configuration.get("norm_eps"),
        partial=configuration.get("norm_partial"),
    )

    return lambda: Branchformer(
        hidden_size,
        linear_size,
        RelPositionMultiHeadedAttention(*mult_att_args),
        ConvolutionalSpatialGatingUnit(*conv_mod_args),
        norm_class=norm_class,
        norm_args=norm_args,
        dropout_rate=dropout_rate,
    )
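

# Illustrative block configuration for this builder (all values are
# assumptions chosen for the example, not defaults of the module):
#
#     branchformer_conf = {
#         "block_type": "branchformer",
#         "hidden_size": 256,
#         "linear_size": 1024,
#         "conv_mod_kernel_size": 31,
#         "heads": 4,
#         "dropout_rate": 0.1,
#     }
#     block_fn = build_branchformer_block(branchformer_conf, main_params)
#     block = block_fn()  # instantiate the actual Branchformer module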


def build_conformer_block(
    configuration: Dict[str, Any],
    main_params: Dict[str, Any],
) -> Callable[[], Conformer]:
    """Build Conformer block.

    Args:
        configuration: Conformer block configuration.
        main_params: Encoder main parameters.

    Returns:
        : Conformer block function.

    """
    hidden_size = configuration["hidden_size"]
    linear_size = configuration["linear_size"]

    pos_wise_args = (
        hidden_size,
        linear_size,
        configuration.get("pos_wise_dropout_rate", 0.0),
        main_params["pos_wise_act"],
    )

    conv_mod_norm_args = {
        "eps": configuration.get("conv_mod_norm_eps", 1e-05),
        "momentum": configuration.get("conv_mod_norm_momentum", 0.1),
    }

    conv_mod_args = (
        hidden_size,
        configuration["conv_mod_kernel_size"],
        main_params["conv_mod_act"],
        conv_mod_norm_args,
        main_params["dynamic_chunk_training"],
    )

    mult_att_args = (
        configuration.get("heads", 4),
        hidden_size,
        configuration.get("att_dropout_rate", 0.0),
        main_params["simplified_att_score"],
    )

    norm_class, norm_args = get_normalization(
        main_params["norm_type"],
        eps=configuration.get("norm_eps"),
        partial=configuration.get("norm_partial"),
    )

    return lambda: Conformer(
        hidden_size,
        RelPositionMultiHeadedAttention(*mult_att_args),
        PositionwiseFeedForward(*pos_wise_args),
        PositionwiseFeedForward(*pos_wise_args),
        ConformerConvolution(*conv_mod_args),
        norm_class=norm_class,
        norm_args=norm_args,
        dropout_rate=configuration.get("dropout_rate", 0.0),
    )
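

# Illustrative Conformer block configuration (values are assumptions). Note
# that the same position-wise arguments instantiate both feed-forward modules
# of the block, and that the convolution module takes (eps, momentum)
# batch-norm style arguments rather than going through get_normalization.
#
#     conformer_conf = {
#         "block_type": "conformer",
#         "hidden_size": 256,
#         "linear_size": 1024,
#         "conv_mod_kernel_size": 15,
#         "conv_mod_norm_momentum": 0.1,
#     }
#     block = build_conformer_block(conformer_conf, main_params)()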


def build_conv1d_block(
    configuration: Dict[str, Any],
    causal: bool,
) -> Callable[[], Conv1d]:
    """Build Conv1d block.

    Args:
        configuration: Conv1d block configuration.
        causal: Whether to use causal convolution.

    Returns:
        : Conv1d block function.

    """
    return lambda: Conv1d(
        configuration["input_size"],
        configuration["output_size"],
        configuration["kernel_size"],
        stride=configuration.get("stride", 1),
        dilation=configuration.get("dilation", 1),
        groups=configuration.get("groups", 1),
        bias=configuration.get("bias", True),
        relu=configuration.get("relu", True),
        batch_norm=configuration.get("batch_norm", False),
        causal=causal,
        dropout_rate=configuration.get("dropout_rate", 0.0),
    )
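

# Illustrative Conv1d block configuration (values are assumptions). In
# build_body_blocks below, `causal` is set from the dynamic_chunk_training
# flag of the main parameters.
#
#     conv1d_conf = {
#         "block_type": "conv1d",
#         "input_size": 256,
#         "output_size": 256,
#         "kernel_size": 3,
#         "batch_norm": True,
#     }
#     block = build_conv1d_block(conv1d_conf, causal=False)()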


def build_ebranchformer_block(
    configuration: Dict[str, Any],
    main_params: Dict[str, Any],
) -> Callable[[], EBranchformer]:
    """Build E-Branchformer block.

    Args:
        configuration: E-Branchformer block configuration.
        main_params: Encoder main parameters.

    Returns:
        : E-Branchformer block function.

    """
    hidden_size = configuration["hidden_size"]
    linear_size = configuration["linear_size"]
    dropout_rate = configuration.get("dropout_rate", 0.0)

    pos_wise_args = (
        hidden_size,
        linear_size,
        configuration.get("pos_wise_dropout_rate", 0.0),
        main_params["pos_wise_act"],
    )

    conv_mod_norm_class, conv_mod_norm_args = get_normalization(
        main_params["conv_mod_norm_type"],
        eps=configuration.get("conv_mod_norm_eps"),
        partial=configuration.get("conv_mod_norm_partial"),
    )

    conv_mod_args = (
        linear_size,
        configuration["conv_mod_kernel_size"],
        conv_mod_norm_class,
        conv_mod_norm_args,
        dropout_rate,
        main_params["dynamic_chunk_training"],
    )

    mult_att_args = (
        configuration.get("heads", 4),
        hidden_size,
        configuration.get("att_dropout_rate", 0.0),
        main_params["simplified_att_score"],
    )

    depthwise_conv_args = (
        hidden_size,
        configuration.get(
            "depth_conv_kernel_size", configuration["conv_mod_kernel_size"]
        ),
        main_params["dynamic_chunk_training"],
    )

    norm_class, norm_args = get_normalization(
        main_params["norm_type"],
        eps=configuration.get("norm_eps"),
        partial=configuration.get("norm_partial"),
    )

    return lambda: EBranchformer(
        hidden_size,
        linear_size,
        RelPositionMultiHeadedAttention(*mult_att_args),
        PositionwiseFeedForward(*pos_wise_args),
        PositionwiseFeedForward(*pos_wise_args),
        ConvolutionalSpatialGatingUnit(*conv_mod_args),
        DepthwiseConvolution(*depthwise_conv_args),
        norm_class=norm_class,
        norm_args=norm_args,
        dropout_rate=dropout_rate,
    )
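

# Illustrative E-Branchformer block configuration (values are assumptions).
# If depth_conv_kernel_size is omitted, the depthwise convolution falls back
# to conv_mod_kernel_size, as the builder above shows.
#
#     ebranchformer_conf = {
#         "block_type": "ebranchformer",
#         "hidden_size": 256,
#         "linear_size": 1024,
#         "conv_mod_kernel_size": 31,
#         "depth_conv_kernel_size": 31,
#     }
#     block = build_ebranchformer_block(ebranchformer_conf, main_params)()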


def build_body_blocks(
    configuration: List[Dict[str, Any]],
    main_params: Dict[str, Any],
    output_size: int,
) -> MultiBlocks:
    """Build encoder body blocks.

    Args:
        configuration: Body blocks configuration.
        main_params: Encoder main parameters.
        output_size: Architecture output size.

    Returns:
        : MultiBlocks module encapsulating all encoder blocks.

    """
    fn_modules = []
    extended_conf = []

    for c in configuration:
        if c.get("num_blocks") is not None:
            extended_conf += c["num_blocks"] * [
                {c_i: c[c_i] for c_i in c if c_i != "num_blocks"}
            ]
        else:
            extended_conf += [c]

    for c in extended_conf:
        block_type = c["block_type"]

        if block_type == "branchformer":
            module = build_branchformer_block(c, main_params)
        elif block_type == "conformer":
            module = build_conformer_block(c, main_params)
        elif block_type == "conv1d":
            module = build_conv1d_block(c, main_params["dynamic_chunk_training"])
        elif block_type == "ebranchformer":
            module = build_ebranchformer_block(c, main_params)
        else:
            raise NotImplementedError(f"Block type '{block_type}' is not supported.")

        fn_modules.append(module)

    return MultiBlocks(
        [fn() for fn in fn_modules],
        output_size,
        norm_class=main_params["after_norm_class"],
        norm_args=main_params["after_norm_args"],
        blockdrop_rate=main_params["blockdrop_rate"],
    )
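

# End-to-end sketch (illustrative, with assumed sizes): expanding a repeated
# block definition via num_blocks and building the encoder body.
#
#     body_conf = [
#         {
#             "block_type": "conformer",
#             "hidden_size": 256,
#             "linear_size": 1024,
#             "conv_mod_kernel_size": 15,
#             "num_blocks": 12,
#         },
#     ]
#     body = build_body_blocks(body_conf, main_params, output_size=256)
#
# The num_blocks entry is expanded into 12 identical single-block
# configurations before each block is instantiated.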