#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import torch
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
[docs]class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
Args:
message (str): Message for error catch
actual_size (int): the short size that cannot pass the subsampling
limit (int): the limit size for subsampling
"""
def __init__(self, message, actual_size, limit):
"""Construct a TooShortUttError for error handler."""
super().__init__(message)
self.actual_size = actual_size
self.limit = limit
[docs]def check_short_utt(ins, size):
"""Check if the utterance is too short for subsampling."""
if isinstance(ins, Conv1dSubsampling2) and size < 5:
return True, 5
if isinstance(ins, Conv1dSubsampling3) and size < 7:
return True, 7
if isinstance(ins, Conv2dSubsampling1) and size < 5:
return True, 5
if isinstance(ins, Conv2dSubsampling2) and size < 7:
return True, 7
if isinstance(ins, Conv2dSubsampling) and size < 7:
return True, 7
if isinstance(ins, Conv2dSubsampling6) and size < 11:
return True, 11
if isinstance(ins, Conv2dSubsampling8) and size < 15:
return True, 15
return False, -1
[docs]class Conv1dSubsampling2(torch.nn.Module):
"""Convolutional 1D subsampling (to 1/2 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct an Conv1dSubsampling2 object."""
super(Conv1dSubsampling2, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(idim, odim, 3, 1),
torch.nn.ReLU(),
torch.nn.Conv1d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim, odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
[docs] def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 2.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 2.
"""
x = x.transpose(2, 1) # (#batch, idim, time)
x = self.conv(x)
b, c, t = x.size()
x = self.out(x.transpose(1, 2).contiguous())
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:1][:, :, :-2:2]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
[docs]class Conv1dSubsampling3(torch.nn.Module):
"""Convolutional 1D subsampling (to 1/3 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct an Conv1dSubsampling3 object."""
super(Conv1dSubsampling3, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(idim, odim, 3, 1),
torch.nn.ReLU(),
torch.nn.Conv1d(odim, odim, 5, 3),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim, odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
[docs] def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 2.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 2.
"""
x = x.transpose(2, 1) # (#batch, idim, time)
x = self.conv(x)
b, c, t = x.size()
x = self.out(x.transpose(1, 2).contiguous())
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:1][:, :, :-4:3]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
[docs]class Conv2dSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct an Conv2dSubsampling object."""
super(Conv2dSubsampling, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
[docs] def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:2][:, :, :-2:2]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
[docs]class Conv2dSubsampling1(torch.nn.Module):
"""Similar to Conv2dSubsampling module, but without any subsampling performed.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct an Conv2dSubsampling1 object."""
super(Conv2dSubsampling1, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 1),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 1),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (idim - 4), odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
[docs] def forward(self, x, x_mask):
"""Pass x through 2 Conv2d layers without subsampling.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim).
where time' = time - 4.
torch.Tensor: Subsampled mask (#batch, 1, time').
where time' = time - 4.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-4]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
[docs]class Conv2dSubsampling2(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct an Conv2dSubsampling2 object."""
super(Conv2dSubsampling2, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 1),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
[docs] def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 2.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 2.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:2][:, :, :-2:1]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
[docs]class Conv2dSubsampling6(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/6 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct an Conv2dSubsampling6 object."""
super(Conv2dSubsampling6, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 5, 3),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
[docs] def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:2][:, :, :-4:3]
[docs]class Conv2dSubsampling8(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/8 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct an Conv2dSubsampling8 object."""
super(Conv2dSubsampling8, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
[docs] def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]