"""Decoder definition."""
from typing import Any, List, Tuple
import torch
from typeguard import check_argument_types
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.state_spaces.model import SequenceModel
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.scorer_interface import BatchScorerInterface
class S4Decoder(AbsDecoder, BatchScorerInterface):
"""S4 decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of hidden vector
input_layer: input layer type
dropinp: input dropout
dropout: dropout parameter applied on every residual and every layer
prenorm: pre-norm vs. post-norm
n_layers: number of layers
transposed: transpose inputs so each layer receives (batch, dim, length)
tie_dropout: tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d
n_repeat: each layer is repeated n times per stage before applying pooling
layer: layer config, must be specified
residual: residual config
norm: normalization config (e.g. layer vs batch)
pool: config for pooling layer per stage
track_norms: log norms of each layer output
drop_path: drop rate for stochastic depth
"""
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        input_layer: str = "embed",
        dropinp: float = 0.0,
        dropout: float = 0.25,
        prenorm: bool = True,
        n_layers: int = 16,
        transposed: bool = False,
        tie_dropout: bool = False,
        n_repeat=1,
        layer=None,
        residual=None,
        norm=None,
        pool=None,
        track_norms=True,
        drop_path: float = 0.0,
    ):
        assert check_argument_types()
        super().__init__()

        self.d_model = encoder_output_size
        # <sos> and <eos> share the last index of the vocabulary
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.odim = vocab_size
        self.dropout = dropout

        if input_layer == "embed":
            self.embed = torch.nn.Embedding(vocab_size, self.d_model)
        else:
            raise NotImplementedError
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        # stack of state-space (S4) layers operating on the embedded token sequence
        self.decoder = SequenceModel(
            self.d_model,
            n_layers=n_layers,
            transposed=transposed,
            dropout=dropout,
            tie_dropout=tie_dropout,
            prenorm=prenorm,
            n_repeat=n_repeat,
            layer=layer,
            residual=residual,
            norm=norm,
            pool=pool,
            track_norms=track_norms,
            dropinp=dropinp,
            drop_path=drop_path,
        )
        self.output = torch.nn.Linear(self.d_model, vocab_size)
    def init_state(self, x: torch.Tensor):
        """Initialize state."""
        return self.decoder.default_state(1, device=x.device)
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        state=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
        memory = hs_pad
        # mask marking the valid (non-padded) encoder frames, shape (batch, 1, maxlen_in)
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )

        emb = self.embed(ys_in_pad)
        z, state = self.decoder(
            emb,
            state=state,
            memory=memory,
            lengths=ys_in_lens,
            mask=memory_mask,
        )
        decoded = self.output(z)
        return decoded, ys_in_lens
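    # Shape-level usage sketch for ``forward`` (assumes a decoder built with
    # encoder_output_size=512 and vocab_size=5000; values are illustrative only):
    #
    #     hs_pad = torch.randn(4, 100, 512)              # encoder output (B, T_in, D)
    #     hlens = torch.tensor([100, 90, 80, 70])        # valid encoder lengths
    #     ys_in_pad = torch.randint(0, 5000, (4, 20))    # <sos>-prefixed target ids
    #     ys_in_lens = torch.tensor([20, 18, 15, 12])
    #     decoded, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    #     # decoded: (4, 20, 5000) pre-softmax token scores; olens is ys_in_lens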
    def score(self, ys, state, x):
        """Score a single hypothesis (not implemented; use batch_score)."""
        raise NotImplementedError
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape `(n_batch, n_vocab)`
                and the next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        # only the newest token is embedded; the recurrent S4 state carries the prefix
        ys = self.embed(ys[:, -1:])

        # workaround for remaining beam width of 1
        if type(states[0]) is list:
            states = states[0]

        assert ys.size(1) == 1, ys.shape
        ys = ys.squeeze(1)

        # single recurrent step through the state-space layers
        ys, states = self.decoder.step(ys, state=states, memory=xs)
        logp = self.output(ys).log_softmax(dim=-1)

        # split the batched layer states back into one state list per hypothesis
        states_list = [
            [state[b].unsqueeze(0) if state is not None else None for state in states]
            for b in range(n_batch)
        ]
        return logp, states_list
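    # Step-wise scoring sketch (normally driven by ESPnet's beam search rather than
    # called by hand; shapes and sizes below are assumptions for illustration):
    #
    #     xs = torch.randn(1, 100, 512)                   # encoder memory, one utterance
    #     state = decoder.init_state(xs)                  # per-layer recurrent S4 states
    #     ys = torch.full((1, 1), decoder.sos, dtype=torch.long)
    #     logp, states = decoder.batch_score(ys, [state], xs)
    #     # logp: (1, vocab_size) log-probabilities of the next token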