# Copyright 2023 Jee-weon Jung
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from typing import Dict, Optional, Tuple
import torch
from typeguard import check_argument_types
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.spk.loss.aamsoftmax import AAMSoftmax
from espnet2.spk.loss.abs_loss import AbsLoss
from espnet2.spk.pooling.abs_pooling import AbsPooling
from espnet2.spk.projector.abs_projector import AbsProjector
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel


class ESPnetSpeakerModel(AbsESPnetModel):
"""
Speaker embedding extraction model.
Core model for diverse speaker-related tasks (e.g., verification, open-set
identification, diarization)
The model architecture comprises mainly 'encoder', 'pooling', and
'projector'.
In common speaker recognition field, the combination of three would be
usually named as 'speaker_encoder' (or speaker embedding extractor).
We splitted it into three for flexibility in future extensions:
- 'encoder' : extract frame-level speaker embeddings.
- 'pooling' : aggregate into single utterance-level embedding.
- 'projector' : (optional) additional processing (e.g., one fully-
connected layer) to derive speaker embedding.
Possibly, in the future, 'pooling' and/or 'projector' can be integrated as
a 'decoder', depending on the extension for joint usage of different tasks
(e.g., ASR, SE, target speaker extraction).
"""
def __init__(
self,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
encoder: Optional[AbsEncoder],
pooling: Optional[AbsPooling],
projector: Optional[AbsProjector],
loss: Optional[AbsLoss],
):
assert check_argument_types()
super().__init__()
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.encoder = encoder
self.pooling = pooling
self.projector = projector
self.loss = loss

    def forward(
self,
speech: torch.Tensor,
# speech_lengths: torch.Tensor = None,
spk_labels: torch.Tensor,
extract_embd: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""
Feed-forward through encoder layers and aggregate into utterance-level
feature.
Args:
speech: (Batch, samples)
speech_lengths: (Batch,)
extract_embd: a flag which doesn't go through the classification
head when set True
spk_labels: (Batch, )
"""
if spk_labels is not None:
assert speech.shape[0] == spk_labels.shape[0], (
speech.shape,
spk_labels.shape,
)
batch_size = speech.shape[0]
# 1. extract low-level feats (e.g., mel-spectrogram or MFCC)
# Will do nothing for raw waveform-based models (e.g., RawNets)
feats, _ = self.extract_feats(speech, None)
frame_level_feats = self.encode_frame(feats)
# 2. aggregation into utterance-level
utt_level_feat = self.pooling(frame_level_feats)
# 3. (optionally) go through further projection(s)
spk_embd = self.project_spk_embd(utt_level_feat)
if extract_embd:
return spk_embd
# 4. calculate loss
loss = self.loss(spk_embd, spk_labels)
stats = dict(loss=loss.detach())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
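
    def extract_feats(
        self, speech: torch.Tensor, speech_lengths: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Minimal sketch of the low-level feature extraction helper that
        # forward() and collect_feats() call. It assumes the usual ESPnet
        # component interfaces, where frontend/specaug/normalize each return
        # a (feats, lengths) tuple; adjust to the concrete modules in use.
        batch_size = speech.shape[0]
        if speech_lengths is None:
            # Assume full-length utterances when lengths are not provided.
            speech_lengths = torch.full(
                (batch_size,), speech.shape[1], dtype=torch.long, device=speech.device
            )

        # 1. Extract low-level feats (e.g., mel-spectrogram or MFCC).
        #    Raw waveform-based models (e.g., RawNets) configure no frontend,
        #    so the waveform passes through unchanged.
        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats, feats_lengths = speech, speech_lengths

        # 2. Data augmentation (e.g., SpecAug), applied only during training.
        if self.specaug is not None and self.training:
            feats, feats_lengths = self.specaug(feats, feats_lengths)

        # 3. Normalization (e.g., global mean-variance normalization).
        if self.normalize is not None:
            feats, feats_lengths = self.normalize(feats, feats_lengths)

        return feats, feats_lengths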

    def encode_frame(self, feats: torch.Tensor) -> torch.Tensor:
        """Extract frame-level speaker embeddings with the encoder."""
        frame_level_feats = self.encoder(feats)
        return frame_level_feats

    def aggregate(self, frame_level_feats: torch.Tensor) -> torch.Tensor:
        """Aggregate frame-level features into an utterance-level feature."""
        utt_level_feat = self.pooling(frame_level_feats)
        return utt_level_feat

    def project_spk_embd(self, utt_level_feat: torch.Tensor) -> torch.Tensor:
        """Apply the optional projector to derive the final speaker embedding."""
if self.projector is not None:
spk_embd = self.projector(utt_level_feat)
else:
spk_embd = utt_level_feat
return spk_embd

    def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
        spk_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self.extract_feats(speech, speech_lengths)
return {"feats": feats}