Source code for espnet2.asr_transducer.decoder.abs_decoder

"""Abstract decoder definition for Transducer models."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

import torch


class AbsDecoder(torch.nn.Module, ABC):
    """Abstract decoder module for Transducer models.

    Every method below is declared with ``@abstractmethod``, so this class
    cannot be instantiated directly; concrete decoders must override all of
    them. Decoder hidden states may be represented as a list of dicts of
    tensors, a list of tensors, or a ``(tensor, optional tensor)`` tuple,
    depending on the concrete implementation.
    """

    @abstractmethod
    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences.

        Returns:
            : Decoder output sequences.

        """
        raise NotImplementedError

    @abstractmethod
    def score(
        self,
        label_sequence: List[int],
        states: Union[
            List[Dict[str, torch.Tensor]],
            List[torch.Tensor],
            Tuple[torch.Tensor, Optional[torch.Tensor]],
        ],
    ) -> Tuple[
        torch.Tensor,
        Union[
            List[Dict[str, torch.Tensor]],
            List[torch.Tensor],
            Tuple[torch.Tensor, Optional[torch.Tensor]],
        ],
    ]:
        """One-step forward hypothesis.

        Args:
            label_sequence: Current label sequence.
            states: Decoder hidden states.

        Returns:
            out: Decoder output sequence.
            states: Decoder hidden states.

        """
        raise NotImplementedError

    @abstractmethod
    def batch_score(
        self,
        hyps: List[Any],
    ) -> Tuple[
        torch.Tensor,
        Union[
            List[Dict[str, torch.Tensor]],
            List[torch.Tensor],
            Tuple[torch.Tensor, Optional[torch.Tensor]],
        ],
    ]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            out: Decoder output sequences.
            states: Decoder hidden states.

        """
        raise NotImplementedError

    @abstractmethod
    def set_device(self, device: torch.device) -> None:
        """Set the device to use.

        Args:
            device: Target device.

        """
        raise NotImplementedError

    @abstractmethod
    def init_state(
        self, batch_size: int
    ) -> Union[
        List[Dict[str, torch.Tensor]],
        List[torch.Tensor],
        Tuple[torch.Tensor, Optional[torch.Tensor]],
    ]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Decoder hidden states.

        """
        raise NotImplementedError

    @abstractmethod
    def select_state(
        self,
        states: Union[
            List[Dict[str, torch.Tensor]],
            List[torch.Tensor],
            Tuple[torch.Tensor, Optional[torch.Tensor]],
        ],
        idx: int = 0,
    ) -> Union[
        List[Dict[str, torch.Tensor]],
        List[torch.Tensor],
        Tuple[torch.Tensor, Optional[torch.Tensor]],
    ]:
        """Get specified ID state from batch of states, if provided.

        Args:
            states: Decoder hidden states.
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID.

        """
        raise NotImplementedError

    @abstractmethod
    def create_batch_states(
        self,
        new_states: List[
            Union[
                List[Dict[str, Optional[torch.Tensor]]],
                List[List[torch.Tensor]],
                Tuple[torch.Tensor, Optional[torch.Tensor]],
            ],
        ],
    ) -> Union[
        List[Dict[str, torch.Tensor]],
        List[torch.Tensor],
        Tuple[torch.Tensor, Optional[torch.Tensor]],
    ]:
        """Create batch of decoder hidden states given a list of new states.

        Args:
            new_states: Decoder hidden states.

        Returns:
            : Decoder hidden states.

        """
        raise NotImplementedError