import contextlib
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from espnet2.asr.frontend.abs_frontend import AbsFrontend
class WhisperFrontend(AbsFrontend):
    """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:

    URL: https://github.com/openai/whisper

    Raw waveforms are turned into a normalized log-Mel spectrogram using the
    STFT/Mel constants from ``whisper.audio``, then passed through the encoder
    of a pretrained Whisper model. The encoder outputs are the features.
    """

    def __init__(
        self,
        whisper_model: str = "small",
        freeze_weights: bool = True,
        download_dir: Optional[str] = None,
    ):
        """Load a pretrained Whisper model.

        Args:
            whisper_model: Model name; must be in ``whisper.available_models()``.
            freeze_weights: If True, :meth:`forward` runs the encoder under
                ``torch.no_grad()``.
            download_dir: Forwarded to ``whisper.load_model`` as
                ``download_root``.

        Raises:
            Exception: The original import error if ``whisper`` cannot be
                imported (re-raised after printing install instructions).
            AssertionError: If ``whisper_model`` is not an available model.
        """
        try:
            import whisper
            from whisper.audio import HOP_LENGTH, N_FFT, N_MELS
        except Exception:
            print("Error: whisper is not properly installed.")
            print(
                "Please install whisper with: cd ${MAIN_ROOT}/tools && "
                "./installers/install_whisper.sh"
            )
            raise

        assert check_argument_types()
        super().__init__()

        # STFT / Mel settings come from whisper.audio so the features are
        # computed with the same constants as whisper's own preprocessing.
        self.n_fft = N_FFT
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH
        self.n_mels = N_MELS
        self.mel_filters = whisper.audio.mel_filters
        self.pad_or_trim = whisper.pad_or_trim

        available = whisper.available_models()
        assert whisper_model in available, (
            f"Unknown whisper model {whisper_model!r}; available: {available}"
        )
        self.whisper = whisper.load_model(whisper_model, download_root=download_dir)
        # NOTE: nn.Module.train() on an enclosing module recursively switches
        # this submodule back to training mode.
        self.whisper.eval()
        self.freeze_weights = freeze_weights

    def output_size(self) -> int:
        """Return the encoder output dimension (last dim of ``ln_post``)."""
        return self.whisper.encoder.ln_post.normalized_shape[-1]

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        ilens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute normalized log10-Mel features.

        Args:
            audio: Waveform(s), shape (samples,) or (batch, samples).
            ilens: Optional number of valid samples per utterance.

        Returns:
            The features, shape (batch, n_mels, frames) (or (n_mels, frames)
            for 1-D input), and the frame lengths ``ilens // hop_length``
            (``None`` when ``ilens`` is ``None``).
        """
        window = torch.hann_window(self.win_length, device=audio.device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )
        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2
        filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        olens = None if ilens is None else ilens // self.hop_length

        # Floor each utterance at (its own max - 8) in log10 units, i.e. keep
        # an 80 dB dynamic range, then shift/scale by 4. The max is taken over
        # the Mel and time axes only, so 1-D and batched input both work.
        per_utt_max = log_spec.amax(dim=(-2, -1), keepdim=True)
        log_spec = torch.maximum(log_spec, per_utt_max - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec, olens

    @staticmethod
    def _conv1d_output_lengths(
        conv: torch.nn.Conv1d, lengths: torch.Tensor
    ) -> torch.Tensor:
        """Map input lengths to output lengths with the nn.Conv1d formula."""
        return (
            lengths
            + 2 * conv.padding[0]
            - conv.dilation[0] * (conv.kernel_size[0] - 1)
            - 1
        ) // conv.stride[0] + 1

    def whisper_encode(
        self,
        input: torch.Tensor,
        ilens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the Whisper encoder on log-Mel features.

        Args:
            input: Features from :meth:`log_mel_spectrogram`; presumably
                (batch, n_mels, frames), channels-first for ``conv1``.
            ilens: Optional number of valid frames per utterance.

        Returns:
            Encoder outputs, shape (batch, time, channels) with ``time``
            truncated to the positional-embedding length, and the matching
            output lengths (``None`` when ``ilens`` is ``None``).
        """
        encoder = self.whisper.encoder
        x = F.gelu(encoder.conv1(input))
        x = F.gelu(encoder.conv2(x))
        x = x.permute(0, 2, 1)  # (batch, channels, time) -> (batch, time, channels)

        # The positional embedding has a fixed number of rows; longer inputs
        # are truncated to fit. The sum is cast back to x's dtype in case the
        # embedding is stored in a different precision.
        max_pos = encoder.positional_embedding.size(0)
        if x.size(1) > max_pos:
            x = x[:, :max_pos, :]
        x = (x + encoder.positional_embedding[: x.size(1), :]).to(x.dtype)

        for block in encoder.blocks:
            x = block(x)
        x = encoder.ln_post(x)

        if ilens is None:
            return x, None
        # Only the two convolutions change the time length.
        olens = self._conv1d_output_lengths(encoder.conv1, ilens)
        olens = self._conv1d_output_lengths(encoder.conv2, olens)
        olens = torch.clamp(olens, max=max_pos)
        return x, olens

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract Whisper encoder features from raw waveforms.

        Args:
            input: Batched waveforms.
            input_lengths: Number of valid samples per utterance.

        Returns:
            Encoder features and their lengths.
        """
        feats, feats_lens = self.log_mel_spectrogram(input, input_lengths)
        # The log-Mel step sits outside this context. Only the encoder is
        # excluded from autograd when the weights are frozen.
        with torch.no_grad() if self.freeze_weights else contextlib.nullcontext():
            feats, feats_lens = self.whisper_encode(feats, feats_lens)
        return feats, feats_lens