Source code for espnet2.asr.decoder.hugging_face_transformers_decoder

#!/usr/bin/env python3
#  2022, University of Stuttgart;  Pavel Denisov
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Hugging Face Transformers Decoder."""

import copy
import logging
from typing import Tuple

import torch
from typeguard import check_argument_types

from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask

try:
    from transformers import AutoModelForSeq2SeqLM

    is_transformers_available = True
except ImportError:
    is_transformers_available = False


class HuggingFaceTransformersDecoder(AbsDecoder):
    """Hugging Face Transformers Decoder.

    Args:
        encoder_output_size: dimension of encoder attention
        model_name_or_path: Hugging Face Transformers model name
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        model_name_or_path: str,
    ):
        assert check_argument_types()
        super().__init__()

        if not is_transformers_available:
            raise ImportError(
                "`transformers` is not available. Please install it via `pip install"
                " transformers` or `cd /path/to/espnet/tools && . ./activate_python.sh"
                " && ./installers/install_transformers.sh`."
            )

        model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

        # Encoder-decoder architectures such as BART/MBart nest the decoder
        # inside `model.model`; others expose it directly as `model.decoder`.
        if hasattr(model, "model"):
            self.decoder = model.model.decoder
        else:
            self.decoder = model.decoder

        self.lm_head = model.lm_head
        self.model_name_or_path = model_name_or_path

        # Keep copies of the pretrained weights so they can be restored later
        # via reload_pretrained_parameters().
        self.decoder_pretrained_params = copy.deepcopy(self.decoder.state_dict())
        self.lm_head_pretrained_params = copy.deepcopy(self.lm_head.state_dict())

        # Project the encoder output to the decoder hidden size if they differ.
        if encoder_output_size != self.decoder.config.hidden_size:
            self.linear_in = torch.nn.Linear(
                encoder_output_size, self.decoder.config.hidden_size
            )
        else:
            self.linear_in = torch.nn.Identity()

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: input token ids, int64 (batch, maxlen_out)
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:

            x: decoded token scores before softmax (batch, maxlen_out, vocab_size)
            olens: (batch,)
        """
        args = {"return_dict": True}

        if self.decoder.__class__.__name__ == "MBartDecoder":
            # MBart uses its EOS token (id 2) as the decoder start token.
            ys_in_pad[:, 0] = 2

        args["input_ids"] = ys_in_pad
        mask = (~make_pad_mask(ys_in_lens)).to(ys_in_pad.device).float()
        args["attention_mask"] = mask

        args["encoder_hidden_states"] = self.linear_in(hs_pad)
        hs_mask = (~make_pad_mask(hlens)).to(hs_pad.device).float()
        args["encoder_attention_mask"] = hs_mask

        x = self.decoder(**args).last_hidden_state
        x = self.lm_head(x)

        olens = mask.sum(1).to(torch.int)

        return x, olens

    def reload_pretrained_parameters(self):
        self.decoder.load_state_dict(self.decoder_pretrained_params)
        self.lm_head.load_state_dict(self.lm_head_pretrained_params)
        logging.info("Pretrained Transformers model parameters reloaded!")
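
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, illustrative
# example of constructing the decoder and running a forward pass on dummy
# tensors. The checkpoint name "facebook/bart-base" and all shapes/lengths
# below are assumptions chosen for demonstration; any Hugging Face seq2seq
# checkpoint loadable with AutoModelForSeq2SeqLM should work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = HuggingFaceTransformersDecoder(
        vocab_size=50265,  # not used for sizing; the checkpoint's lm_head decides
        encoder_output_size=256,  # projected to the decoder hidden size by linear_in
        model_name_or_path="facebook/bart-base",
    )

    batch, enc_len, dec_len = 2, 30, 7
    hs_pad = torch.randn(batch, enc_len, 256)  # encoder memory (batch, maxlen_in, feat)
    hlens = torch.tensor([30, 25])  # valid encoder lengths per utterance
    ys_in_pad = torch.randint(0, 1000, (batch, dec_len))  # target token ids
    ys_in_lens = torch.tensor([7, 5])  # valid target lengths per utterance

    logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    print(logits.shape)  # (batch, dec_len, checkpoint vocabulary size)
    print(olens)  # tensor([7, 5], dtype=torch.int32) -- per-utterance output lengths

    # Restore the original Hugging Face weights, e.g. after fine-tuning.
    decoder.reload_pretrained_parameters()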