espnet2.st package

espnet2.st.espnet_model
class espnet2.st.espnet_model.ESPnetSTModel(vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], frontend: Optional[espnet2.asr.frontend.abs_frontend.AbsFrontend], specaug: Optional[espnet2.asr.specaug.abs_specaug.AbsSpecAug], normalize: Optional[espnet2.layers.abs_normalize.AbsNormalize], preencoder: Optional[espnet2.asr.preencoder.abs_preencoder.AbsPreEncoder], encoder: espnet2.asr.encoder.abs_encoder.AbsEncoder, postencoder: Optional[espnet2.asr.postencoder.abs_postencoder.AbsPostEncoder], decoder: espnet2.asr.decoder.abs_decoder.AbsDecoder, extra_asr_decoder: Optional[espnet2.asr.decoder.abs_decoder.AbsDecoder], extra_mt_decoder: Optional[espnet2.asr.decoder.abs_decoder.AbsDecoder], ctc: Optional[espnet2.asr.ctc.CTC], src_vocab_size: Optional[int], src_token_list: Union[Tuple[str, ...], List[str], None], asr_weight: float = 0.0, mt_weight: float = 0.0, mtlalpha: float = 0.0, ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, report_bleu: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', extract_feats_in_collect_stats: bool = True)

Bases: espnet2.train.abs_espnet_model.AbsESPnetModel
CTC-attention hybrid Encoder-Decoder model
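The `asr_weight`, `mt_weight`, and `mtlalpha` arguments control how the auxiliary ASR and MT objectives are mixed into the main ST loss, with `mtlalpha` interpolating between CTC and attention losses inside the ASR branch. A minimal torch-free sketch of this weighting (the function name and scalar losses are illustrative, and the exact combination should be verified against the source):

```python
def combine_st_losses(loss_st: float,
                      loss_asr_att: float,
                      loss_asr_ctc: float,
                      loss_mt: float,
                      asr_weight: float = 0.0,
                      mt_weight: float = 0.0,
                      mtlalpha: float = 0.0) -> float:
    """Illustrative mixing of the ST / ASR / MT branch losses.

    The ASR branch is itself a CTC/attention interpolation controlled
    by ``mtlalpha``; ``asr_weight`` and ``mt_weight`` then blend the
    auxiliary branches with the primary ST objective.
    """
    # CTC-attention hybrid loss for the auxiliary ASR branch
    loss_asr = mtlalpha * loss_asr_ctc + (1.0 - mtlalpha) * loss_asr_att
    # The remaining weight mass goes to the primary ST objective
    st_weight = 1.0 - asr_weight - mt_weight
    return st_weight * loss_st + asr_weight * loss_asr + mt_weight * loss_mt

# With the defaults (asr_weight = mt_weight = 0.0) only the ST loss remains.
print(combine_st_losses(2.0, 1.0, 3.0, 0.5))  # 2.0
print(combine_st_losses(2.0, 1.0, 3.0, 0.5,
                        asr_weight=0.3, mt_weight=0.2, mtlalpha=0.5))  # 1.7
```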
collect_feats(speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, src_text: Optional[torch.Tensor] = None, src_text_lengths: Optional[torch.Tensor] = None, **kwargs) → Dict[str, torch.Tensor]
encode(speech: torch.Tensor, speech_lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Frontend + Encoder. Note that this method is used by st_inference.py.
Parameters:
speech – (Batch, Length, …)
speech_lengths – (Batch,)
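The (Batch, Length, …) / (Batch,) convention means variable-length utterances are zero-padded to the longest one in the batch, with the true frame counts passed separately. A torch-free sketch of building such a batch (plain nested lists stand in for torch.Tensors; the helper name is illustrative):

```python
def pad_batch(feature_seqs):
    """Zero-pad variable-length feature sequences to a common length.

    feature_seqs: list of [num_frames][feat_dim] lists.
    Returns (padded, lengths) mirroring the (Batch, Length, ...) speech
    tensor and the (Batch,) speech_lengths tensor expected by encode().
    """
    feat_dim = len(feature_seqs[0][0])
    max_len = max(len(seq) for seq in feature_seqs)
    lengths = [len(seq) for seq in feature_seqs]
    padded = [seq + [[0.0] * feat_dim] * (max_len - len(seq))
              for seq in feature_seqs]
    return padded, lengths

# Two utterances of 3 and 5 frames, feature dimension 2.
batch = [[[1.0, 1.0]] * 3, [[2.0, 2.0]] * 5]
padded, lengths = pad_batch(batch)
print(lengths)         # [3, 5]
print(len(padded[0]))  # 5 -- padded up to the longest utterance
```

In actual use the equivalent is done with `torch.nn.utils.rnn.pad_sequence` on torch.Tensors before calling `encode`.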
forward(speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, src_text: Optional[torch.Tensor] = None, src_text_lengths: Optional[torch.Tensor] = None, **kwargs) → Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]

Frontend + Encoder + Decoder + Calc loss.
Parameters:
speech – (Batch, Length, …)
speech_lengths – (Batch,)
text – (Batch, Length)
text_lengths – (Batch,)
src_text – (Batch, Length)
src_text_lengths – (Batch,)
kwargs – “utt_id” is among the input keys.
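The token targets `text` and `src_text` are padded with `ignore_id` (default -1) so that loss computation can skip padding positions, while `text_lengths` records the true lengths. A small illustration (plain lists stand in for torch.Tensors; the helper name is illustrative):

```python
IGNORE_ID = -1  # matches the ignore_id default of ESPnetSTModel

def pad_token_batch(token_seqs, pad_value=IGNORE_ID):
    """Pad token-id sequences to a (Batch, Length) grid plus true lengths."""
    max_len = max(len(seq) for seq in token_seqs)
    lengths = [len(seq) for seq in token_seqs]
    padded = [seq + [pad_value] * (max_len - len(seq)) for seq in token_seqs]
    return padded, lengths

text, text_lengths = pad_token_batch([[5, 7, 9], [4, 2]])
print(text)          # [[5, 7, 9], [4, 2, -1]]
print(text_lengths)  # [3, 2]
```

Positions equal to `ignore_id` are excluded from the cross-entropy target, which is why -1 (an impossible token id) is used rather than a real vocabulary index.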