#!/usr/bin/env python3
# 2020, Technische Universität München; Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Embedding Frontend for text based inputs."""
from typing import Tuple
import torch
from typeguard import check_argument_types
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding


class Embedding(AbsFrontend):
"""Embedding Frontend for text based inputs."""
def __init__(
self,
input_size: int = 400,
embed_dim: int = 400,
pos_enc_class=PositionalEncoding,
positional_dropout_rate: float = 0.1,
):
"""Initialize.
Args:
input_size: Number of input tokens.
embed_dim: Embedding Size.
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
positional_dropout_rate: dropout rate after adding positional encoding
"""
assert check_argument_types()
super().__init__()
self.embed_dim = embed_dim
# TODO(sdalmia): check for padding idx
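        # Map token ids to dense vectors, then add positional encodings;
        # ESPnet's PositionalEncoding also scales the embeddings by
        # sqrt(embed_dim) and applies dropout internally.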
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, embed_dim),
pos_enc_class(embed_dim, positional_dropout_rate),
)

    def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply a sliding window on the input.
Args:
input: Input (B, T) or (B, T,D), with D.
input_lengths: Input lengths within batch.
Returns:
Tensor: Output with dimensions (B, T, D).
Tensor: Output lengths within batch.
"""
x = self.embed(input)
return x, input_lengths

    def output_size(self) -> int:
"""Return output length of feature dimension D, i.e. the embedding dim."""
return self.embed_dim
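

# A minimal usage sketch (illustrative only): the vocabulary size, embedding
# dimension, and token/length values below are assumed example values, not
# ESPnet defaults.
if __name__ == "__main__":
    frontend = Embedding(input_size=1000, embed_dim=256)
    tokens = torch.randint(0, 1000, (2, 13))  # (B, T) integer token ids
    lengths = torch.tensor([13, 9])  # valid lengths for each batch entry
    embedded, out_lengths = frontend(tokens, lengths)
    print(embedded.shape)  # torch.Size([2, 13, 256]); 256 == frontend.output_size()
    print(out_lengths)  # lengths pass through unchanged: tensor([13,  9])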