"""Sequential implementation of Recurrent Neural Network Language Model."""
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from typeguard import check_argument_types

from espnet2.lm.abs_model import AbsLM
class SequentialRNNLM(AbsLM):
    """Sequential RNNLM.

    The model is an embedding layer, a stacked recurrent network
    (``batch_first=True``) and a linear projection back to the vocabulary,
    with dropout applied to the embeddings and to the recurrent outputs.

    See also:
        https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py

    Args:
        vocab_size: Number of tokens; size of the embedding table and of the
            output projection.
        unit: Embedding dimension, i.e. the input size of the recurrent layers.
        nhid: Hidden size of the recurrent layers. ``None`` means ``unit``.
        nlayers: Number of stacked recurrent layers.
        dropout_rate: Dropout probability used for the embedding/output dropout
            and passed as ``dropout`` to the recurrent module.
        tie_weights: Share the output projection weight with the embedding
            weight. Requires ``nhid == unit``.
        rnn_type: One of ``lstm``, ``gru``, ``rnn_tanh`` or ``rnn_relu``
            (case-insensitive).
        ignore_id: Token id passed to the embedding as ``padding_idx``.

    Raises:
        ValueError: If ``rnn_type`` is not a supported option, or if
            ``tie_weights`` is set while ``nhid`` differs from ``unit``.
    """

    def __init__(
        self,
        vocab_size: int,
        unit: int = 650,
        nhid: Optional[int] = None,
        nlayers: int = 2,
        dropout_rate: float = 0.0,
        tie_weights: bool = False,
        rnn_type: str = "lstm",
        ignore_id: int = 0,
    ):
        assert check_argument_types()
        super().__init__()

        ninp = unit
        if nhid is None:
            nhid = unit
        rnn_type = rnn_type.upper()

        self.drop = nn.Dropout(dropout_rate)
        self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)

        if rnn_type in ("LSTM", "GRU"):
            rnn_class = getattr(nn, rnn_type)
            self.rnn = rnn_class(
                ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
            )
        else:
            nonlinearities = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}
            # Explicit membership test instead of catching KeyError, so the
            # ValueError is not reported as raised "during handling" of another
            # exception.
            if rnn_type not in nonlinearities:
                raise ValueError(
                    "An invalid option for `rnn_type` was supplied, "
                    "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"
                )
            self.rnn = nn.RNN(
                ninp,
                nhid,
                nlayers,
                nonlinearity=nonlinearities[rnn_type],
                dropout=dropout_rate,
                batch_first=True,
            )

        self.decoder = nn.Linear(nhid, vocab_size)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models"
        # (Press & Wolf 2016) https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers:
        # A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            # The decoder weight is (vocab_size, nhid) and the embedding weight
            # is (vocab_size, unit); they can only be shared when sizes match.
            if nhid != ninp:
                raise ValueError(
                    "When using tie_weights, nhid must be equal to unit"
                )
            self.decoder.weight = self.encoder.weight

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def zero_state(self):
        """Initialize LM state filled with zero values.

        Returns:
            A ``(nlayers, nhid)`` tensor, or a ``(h, c)`` tuple of two such
            tensors when the recurrent module is an LSTM. This is the
            per-hypothesis layout that ``batch_score`` stacks along a new
            batch axis. The tensors use the dtype and device of the recurrent
            module's parameters.
        """
        # Follow the RNN's own parameters so the state is usable after the
        # model has been moved to another device or cast to another dtype.
        ref = next(self.rnn.parameters())
        shape = (self.nlayers, self.nhid)
        h = torch.zeros(shape, dtype=ref.dtype, device=ref.device)
        if isinstance(self.rnn, torch.nn.LSTM):
            c = torch.zeros(shape, dtype=ref.dtype, device=ref.device)
            return h, c
        return h

    def forward(
        self,
        input: torch.Tensor,
        hidden: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Compute vocabulary logits for every position of ``input``.

        Args:
            input: Token ids, indexed as (batch, length).
            hidden: State handed to the recurrent module as-is: a tensor, an
                ``(h, c)`` tuple for LSTM, or ``None``.

        Returns:
            Tuple of the logits with shape (batch, length, vocab_size) and the
            state returned by the recurrent module.
        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)

        n_batch, length, n_feat = output.size(0), output.size(1), output.size(2)
        # Flatten (batch, length) into one axis for the projection; reshape
        # copies only when the RNN output is not contiguous.
        decoded = self.decoder(output.reshape(n_batch * length, n_feat))
        return decoded.view(n_batch, length, decoded.size(1)), hidden

    def score(
        self,
        y: torch.Tensor,
        state: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Score new token.

        Only the last token of ``y`` is fed to the network; the earlier
        prefix is expected to be summarised by ``state``.

        Args:
            y: 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens, passed to ``forward``
                unchanged (e.g. the state returned by the previous call).
            x: 2D encoder feature that generates ys. Not used by this model.

        Returns:
            Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys
        """
        # Feed the last token as a (batch=1, length=1) input.
        logits, new_state = self(y[-1].view(1, 1), state)
        logp = logits.log_softmax(dim=-1).view(-1)
        return logp, new_state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens, one entry per
                hypothesis: a (Nlayers, Dim) tensor, an ``(h, c)`` tuple of
                such tensors for LSTM, or ``None``. Only the first entry is
                checked for ``None``.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
                Not used by this model.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        if states[0] is None:
            # No previous state: the recurrent module gets None.
            batch_state = None
        elif isinstance(self.rnn, torch.nn.LSTM):
            # states: Batch x 2 x (Nlayers, Dim) -> 2 x (Nlayers, Batch, Dim)
            h = torch.stack([h for h, c in states], dim=1)
            c = torch.stack([c for h, c in states], dim=1)
            batch_state = h, c
        else:
            # states: Batch x (Nlayers, Dim) -> (Nlayers, Batch, Dim)
            batch_state = torch.stack(states, dim=1)

        # Only the newest token of each hypothesis is fed to the network.
        logits, batch_state = self(ys[:, -1:], batch_state)

        # logits: (Batch, 1, Nvocab) -> (Batch, Nvocab)
        assert logits.size(1) == 1, logits.shape
        logp = logits.squeeze(1).log_softmax(dim=-1)

        # Split the batch axis (dim=1) back into one state per hypothesis.
        if isinstance(self.rnn, torch.nn.LSTM):
            # h, c: (Nlayers, Batch, Dim) -> Batch x 2 x (Nlayers, Dim)
            h, c = batch_state
            new_states = list(zip(h.unbind(dim=1), c.unbind(dim=1)))
        else:
            # (Nlayers, Batch, Dim) -> Batch x (Nlayers, Dim)
            new_states = list(batch_state.unbind(dim=1))
        return logp, new_states