# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Conformer common arguments."""
def add_arguments_rnn_encoder_common(group):
    """Define common arguments for RNN encoder."""
    group.add_argument(
        "--etype",
        default="blstmp",
        type=str,
        choices=[
            "lstm",
            "blstm",
            "lstmp",
            "blstmp",
            "vgglstmp",
            "vggblstmp",
            "vgglstm",
            "vggblstm",
            "gru",
            "bgru",
            "grup",
            "bgrup",
            "vgggrup",
            "vggbgrup",
            "vgggru",
            "vggbgru",
        ],
        help="Type of encoder network architecture",
    )
    group.add_argument(
        "--elayers",
        default=4,
        type=int,
        help="Number of encoder layers",
    )
    group.add_argument(
        "--eunits",
        "-u",
        default=300,
        type=int,
        help="Number of encoder hidden units",
    )
    group.add_argument(
        "--eprojs", default=320, type=int, help="Number of encoder projection units"
    )
    group.add_argument(
        "--subsample",
        default="1",
        type=str,
        help="Subsample input frames: x_y_z means "
        "subsample by a factor of x at the 1st layer, "
        "by y at the 2nd layer, etc.",
    )
    return group
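

# Illustrative sketch (not part of the original module): one way an "x_y_z"-style
# --subsample value could be expanded into per-layer subsampling factors.
# The helper name and the padding rule below are assumptions made for
# illustration only, not the project's own parsing code.
def _parse_subsample_spec(spec, elayers):
    """Split an underscore-separated spec into one integer factor per layer."""
    factors = [int(s) for s in spec.split("_")]
    # Assume missing entries mean "no subsampling" (factor 1) for the
    # remaining layers, and ignore any extra entries.
    factors += [1] * (elayers - len(factors))
    return factors[:elayers]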


def add_arguments_rnn_decoder_common(group):
    """Define common arguments for RNN decoder."""
    group.add_argument(
        "--dtype",
        default="lstm",
        type=str,
        choices=["lstm", "gru"],
        help="Type of decoder network architecture",
    )
    group.add_argument(
        "--dlayers", default=1, type=int, help="Number of decoder layers"
    )
    group.add_argument(
        "--dunits", default=320, type=int, help="Number of decoder hidden units"
    )
    group.add_argument(
        "--dropout-rate-decoder",
        default=0.0,
        type=float,
        help="Dropout rate for the decoder",
    )
    group.add_argument(
        "--sampling-probability",
        default=0.0,
        type=float,
        help="Ratio of predicted labels fed back to the decoder",
    )
    group.add_argument(
        "--lsm-type",
        const="",
        default="",
        type=str,
        nargs="?",
        choices=["", "unigram"],
        help="Apply label smoothing with a specified distribution type",
    )
    return group


def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention."""
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument(
        "--awin", default=5, type=int, help="Window size for location2d attention"
    )
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi-head attention",
    )
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels "
        "(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters "
        "(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group
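

# Minimal usage sketch (an illustrative addition, not part of the original
# module): each helper above takes an argparse argument group, registers its
# options, and returns the group. Only the standard-library argparse API is
# assumed here; the group title and the printed fields are arbitrary choices.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="RNN E2E-ASR arguments (sketch)")
    group = parser.add_argument_group("rnn model setting")
    group = add_arguments_rnn_encoder_common(group)
    group = add_arguments_rnn_decoder_common(group)
    group = add_arguments_rnn_attention_common(group)

    # Parse an empty command line to inspect the defaults, e.g. etype="blstmp".
    args = parser.parse_args([])
    print(args.etype, args.elayers, args.dtype, args.dunits, args.atype)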