from abc import ABC, abstractmethod
import torch
class AbsPostDecoder(torch.nn.Module, ABC):
    """Abstract interface for post-decoder modules.

    A post-decoder is a ``torch.nn.Module`` whose ``forward`` takes four
    transcript tensors (input ids, attention mask, token type ids, position
    ids) and returns a tensor. Concrete subclasses must implement all three
    abstract methods below. Instantiating this class directly, or a subclass
    that leaves any of them unimplemented, raises ``TypeError``.

    As with any ``torch.nn.Module`` subclass, implementations that define
    ``__init__`` must call ``super().__init__()`` before registering
    parameters or submodules.
    """

    @abstractmethod
    def output_size(self) -> int:
        """Return the output size of this post-decoder.

        NOTE(review): presumably the size of the last dimension of the tensor
        returned by ``forward``. Confirm against concrete subclasses.

        Returns:
            The output size as an ``int``.

        Raises:
            NotImplementedError: Always, when called on the abstract base
                (e.g. via ``super()``).
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        transcript_input_ids: torch.LongTensor,
        transcript_attention_mask: torch.LongTensor,
        transcript_token_type_ids: torch.LongTensor,
        transcript_position_ids: torch.LongTensor,
    ) -> torch.Tensor:
        """Run the post-decoder on tokenized transcript inputs.

        NOTE(review): the argument names match the outputs of a BERT-style
        tokenizer. Shapes are presumably ``(batch, seq_len)``. This interface
        does not fix them, so TODO confirm with concrete subclasses.

        Args:
            transcript_input_ids: Token ids of the transcript.
            transcript_attention_mask: Mask marking which positions are real
                tokens versus padding.
            transcript_token_type_ids: Segment / token-type ids.
            transcript_position_ids: Position ids for each token.

        Returns:
            A ``torch.Tensor`` produced by the subclass.

        Raises:
            NotImplementedError: Always, when called on the abstract base.
        """
        raise NotImplementedError

    @abstractmethod
    def convert_examples_to_features(
        self, data: list, max_seq_length: int, output_size: int
    ):
        """Convert raw examples into model input features.

        The return type is deliberately left unannotated. Its structure is
        defined by each concrete subclass.

        Args:
            data: List of raw examples. The element type is subclass-defined.
            max_seq_length: Maximum sequence length for the produced
                features (presumably used for truncation/padding; confirm).
            output_size: Output size to build the features for. It is
                presumably the value returned by ``output_size()``; confirm
                with callers.

        Raises:
            NotImplementedError: Always, when called on the abstract base.
        """
        # Raising (rather than ``pass``) makes an accidental ``super()`` call
        # from a subclass fail loudly instead of silently returning ``None``.
        raise NotImplementedError