import copy
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.specaug.specaug import SpecAug


class OpenAIWhisperEncoder(AbsEncoder):
    """Transformer-based Speech Encoder from OpenAI's Whisper Model:

    URL: https://github.com/openai/whisper
    """
    def __init__(
        self,
        input_size: int = 1,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: Optional[str] = None,
        use_specaug: bool = False,
        specaug_conf: Union[dict, None] = None,
        do_pad_trim: bool = False,
    ):
        try:
            import whisper
            from whisper.audio import HOP_LENGTH, N_FFT, N_MELS, N_SAMPLES
        except Exception as e:
            print("Error: whisper is not properly installed.")
            print(
                "Please install whisper with: cd ${MAIN_ROOT}/tools &&",
                "./installers/install_whisper.sh",
            )
            raise e

        assert check_argument_types()
        super().__init__()

        self.n_fft = N_FFT
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH
        self.n_mels = N_MELS
        self.mel_filters = whisper.audio.mel_filters

        # note that the original Whisper model doesn't use dropout
        self.dropout = torch.nn.Dropout(dropout_rate)

        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(whisper_model, download_root=download_dir)
        self.encoders = copy.deepcopy(_model.encoder)
        self.encoders.train()
        del _model

        if use_specaug:
            self.specaug = SpecAug(**specaug_conf)
        else:
            self.specaug = None

        self.do_pad_trim = do_pad_trim
        self.pad_samples = N_SAMPLES

    def output_size(self) -> int:
        return self.encoders.ln_post.normalized_shape[-1]

    def pad_or_trim(
        self,
        array: torch.Tensor,
        length: int,
        axis: int = -1,
    ) -> torch.Tensor:
        """Pad or trim the audio array to N_SAMPLES.

        Used in zero-shot inference cases.
        """
        if array.shape[axis] > length:
            array = array.index_select(
                dim=axis, index=torch.arange(length).to(array.device)
            )

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])

        return array
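
    # For reference: whisper.audio.N_SAMPLES is 30 s * 16 kHz = 480000 samples,
    # so with ``do_pad_trim`` enabled a 5 s utterance (80000 samples) is
    # right-padded with 400000 zeros, while a 40 s one is cut back to 480000.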

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Use log-mel spectrogram computation native to Whisper training."""
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )

        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2

        filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        if ilens is not None:
            olens = ilens // self.hop_length
        else:
            olens = None

        log_spec = torch.maximum(
            log_spec,
            log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0,
        )
        log_spec = (log_spec + 4.0) / 4.0

        return log_spec, olens
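
    # With the Whisper defaults imported above (N_FFT=400, HOP_LENGTH=160, and
    # N_MELS=80 in the Whisper version this import layout corresponds to), 16 kHz
    # audio becomes 80-dim log-mel features at 100 frames per second, e.g. a
    # padded 30 s clip maps to a (batch, 80, 3000) spectrogram.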

    def whisper_encode(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        x = F.gelu(self.encoders.conv1(input))
        x = F.gelu(self.encoders.conv2(x))
        x = x.permute(0, 2, 1)

        n_frames = x.size(1)
        max_pos = self.encoders.positional_embedding.size(0)

        if n_frames <= max_pos:
            x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
        else:
            # due to the fixed-size positional embedding, audio longer than
            # 30 sec is truncated here
            x = x[:, :max_pos, :] + self.encoders.positional_embedding

        x = self.dropout(x)

        for layer, block in enumerate(self.encoders.blocks):
            x = block(x)
            if layer < len(self.encoders.blocks) - 1:
                x = self.dropout(x)

        x = self.encoders.ln_post(x)

        if ilens is not None:
            olens = (
                1
                + (
                    ilens
                    - self.encoders.conv2.kernel_size[0]
                    + 2 * self.encoders.conv2.padding[0]
                )
                // self.encoders.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None

        return x, olens
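
    # Worked example of the length formula above, assuming the released Whisper
    # checkpoints (conv2 has kernel_size=3, stride=2, padding=1 and the
    # positional embedding holds max_pos=1500 frames):
    #   a full 30 s input gives ilens=3000 mel frames, so
    #   olens = 1 + (3000 - 3 + 2 * 1) // 2 = 1500, exactly max_pos;
    #   in general the stride-2 convolution roughly halves the frame count.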

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        if self.do_pad_trim:
            xs_pad = self.pad_or_trim(xs_pad, self.pad_samples)

        feats, feats_lens = self.log_mel_spectrogram(xs_pad, ilens)

        if self.specaug is not None and self.encoders.training:
            feats = torch.transpose(feats, 1, 2)
            feats, feats_lens = self.specaug(feats, feats_lens)
            feats = torch.transpose(feats, 1, 2)

        xs_pad, olens = self.whisper_encode(feats, feats_lens)

        return xs_pad, olens, None
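

# A minimal usage sketch (not part of the class above): it assumes the `whisper`
# package is installed and runs a batch of raw 16 kHz waveforms through the
# encoder. The "base" model choice, random waveforms, and lengths are
# illustrative only.
if __name__ == "__main__":
    encoder = OpenAIWhisperEncoder(whisper_model="base", do_pad_trim=True)
    encoder.eval()

    waveforms = torch.randn(2, 16000 * 5)           # 2 utterances, 5 s buffers
    lengths = torch.tensor([16000 * 5, 16000 * 3])  # true lengths in samples

    with torch.no_grad():
        enc_out, enc_lens, _ = encoder(waveforms, lengths)

    # enc_out: (batch, frames, encoder.output_size()); enc_lens: valid frame counts
    print(enc_out.shape, enc_lens)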