#!/usr/bin/env python3
"""
This script is used to construct End-to-End models of multi-speaker ASR.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import argparse
import logging
import math
import os
import sys
from itertools import groupby
import numpy as np
import torch
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.e2e_asr_common import get_vgg2l_odim, label_smoothing_dist
from espnet.nets.pytorch_backend.ctc import ctc_for
from espnet.nets.pytorch_backend.e2e_asr import E2E as E2EASR
from espnet.nets.pytorch_backend.e2e_asr import Reporter
from espnet.nets.pytorch_backend.frontends.feature_transform import ( # noqa: H301
feature_transform_for,
)
from espnet.nets.pytorch_backend.frontends.frontend import frontend_for
from espnet.nets.pytorch_backend.initialization import (
lecun_normal_init_parameters,
set_forget_bias_to_one,
)
from espnet.nets.pytorch_backend.nets_utils import (
get_subsample,
make_pad_mask,
pad_list,
to_device,
to_torch_tensor,
)
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import RNNP, VGG2L
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for as encoder_for_single
CTC_LOSS_THRESHOLD = 10000
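# losses at or above CTC_LOSS_THRESHOLD (or NaN) are treated as invalid and are
# not reported (see E2E.forward)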
class PIT(object):
    """Permutation Invariant Training (PIT) module.
    :param int num_spkrs: number of speakers for the PIT process (2 or 3)
    """
def __init__(self, num_spkrs):
"""Initialize PIT module."""
self.num_spkrs = num_spkrs
# [[0, 1], [1, 0]] or
# [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]]
self.perm_choices = []
initial_seq = np.linspace(0, num_spkrs - 1, num_spkrs, dtype=np.int64)
self.permutationDFS(initial_seq, 0)
# [[0, 3], [1, 2]] or
# [[0, 4, 8], [0, 5, 7], [1, 3, 8], [1, 5, 6], [2, 4, 6], [2, 3, 7]]
self.loss_perm_idx = np.linspace(
0, num_spkrs * (num_spkrs - 1), num_spkrs, dtype=np.int64
).reshape(1, num_spkrs)
self.loss_perm_idx = (self.loss_perm_idx + np.array(self.perm_choices)).tolist()
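        # Each row of loss_perm_idx indexes the flat (num_spkrs^2,) pairwise loss
        # vector [h1r1, h1r2, ..., hNrN]: for permutation p, hypothesis k is paired
        # with reference p[k], i.e. flat index k * num_spkrs + p[k].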
    def min_pit_sample(self, loss):
        """Compute the PIT loss for one sample.
        :param 1-D torch.Tensor loss: list of losses for one sample,
            including [h1r1, h1r2, h2r1, h2r2] or
            [h1r1, h1r2, h1r3, h2r1, h2r2, h2r3, h3r1, h3r2, h3r3]
        :return: minimum loss of the best permutation
        :rtype: torch.Tensor (1)
        :return: the best permutation
        :rtype: list (len=num_spkrs)
        """
score_perms = (
torch.stack(
[torch.sum(loss[loss_perm_idx]) for loss_perm_idx in self.loss_perm_idx]
)
/ self.num_spkrs
)
perm_loss, min_idx = torch.min(score_perms, 0)
permutation = self.perm_choices[min_idx]
return perm_loss, permutation
    def pit_process(self, losses):
        """Compute the PIT loss for a batch.
        :param torch.Tensor losses: pairwise losses (B, 1|4|9)
        :return: mean of the per-sample minimum losses under the best permutation
        :rtype: torch.Tensor (1)
        :return: the best permutation of each sample
        :rtype: torch.LongTensor (B, 1|2|3)
        """
bs = losses.size(0)
ret = [self.min_pit_sample(losses[i]) for i in range(bs)]
loss_perm = torch.stack([r[0] for r in ret], dim=0).to(losses.device) # (B)
permutation = torch.tensor([r[1] for r in ret]).long().to(losses.device)
return torch.mean(loss_perm), permutation
    def permutationDFS(self, source, start):
        """Get permutations with DFS.
        The final result is all permutations of the 'source' sequence.
        e.g. [[1, 2], [2, 1]] or
        [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 2, 1], [3, 1, 2]]
        :param np.ndarray source: 1-D array of length num_spkrs, e.g. [1, 2, ..., N]
        :param int start: the start point to permute
        """
if start == len(source) - 1: # reach final state
self.perm_choices.append(source.tolist())
for i in range(start, len(source)):
# swap values at position start and i
source[start], source[i] = source[i], source[start]
self.permutationDFS(source, start + 1)
# reverse the swap
source[start], source[i] = source[i], source[start]
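# A minimal usage sketch of PIT (illustrative values only):
#
#     pit = PIT(num_spkrs=2)
#     # pairwise losses for one sample, ordered [h1r1, h1r2, h2r1, h2r2]
#     loss, perm = pit.min_pit_sample(torch.tensor([1.0, 5.0, 6.0, 2.0]))
#     # -> loss == 1.5 ((1.0 + 2.0) / 2), perm == [0, 1]
#     batch_loss, perms = pit.pit_process(torch.tensor([[1.0, 5.0, 6.0, 2.0]]))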
class E2E(ASRInterface, torch.nn.Module):
"""E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
    @staticmethod
def add_arguments(parser):
"""Add arguments."""
E2EASR.encoder_add_arguments(parser)
E2E.encoder_mix_add_arguments(parser)
E2EASR.attention_add_arguments(parser)
E2EASR.decoder_add_arguments(parser)
return parser
    @staticmethod
def encoder_mix_add_arguments(parser):
"""Add arguments for multi-speaker encoder."""
group = parser.add_argument_group("E2E encoder setting for multi-speaker")
# asr-mix encoder
group.add_argument(
"--spa",
action="store_true",
help="Enable speaker parallel attention "
"for multi-speaker speech recognition task.",
)
group.add_argument(
"--elayers-sd",
default=4,
type=int,
help="Number of speaker differentiate encoder layers"
"for multi-speaker speech recognition task.",
)
return parser
    def get_total_subsampling_factor(self):
"""Get total subsampling factor."""
return self.enc.conv_subsampling_factor * int(np.prod(self.subsample))
def __init__(self, idim, odim, args):
"""Initialize multi-speaker E2E module."""
super(E2E, self).__init__()
torch.nn.Module.__init__(self)
self.mtlalpha = args.mtlalpha
assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
self.etype = args.etype
self.verbose = args.verbose
# NOTE: for self.build method
args.char_list = getattr(args, "char_list", None)
self.char_list = args.char_list
self.outdir = args.outdir
self.space = args.sym_space
self.blank = args.sym_blank
self.reporter = Reporter()
self.num_spkrs = args.num_spkrs
self.spa = args.spa
self.pit = PIT(self.num_spkrs)
# below means the last number becomes eos/sos ID
# note that sos/eos IDs are identical
self.sos = odim - 1
self.eos = odim - 1
# subsample info
self.subsample = get_subsample(args, mode="asr", arch="rnn_mix")
# label smoothing info
if args.lsm_type and os.path.isfile(args.train_json):
logging.info("Use label smoothing with " + args.lsm_type)
labeldist = label_smoothing_dist(
odim, args.lsm_type, transcript=args.train_json
)
else:
labeldist = None
if getattr(args, "use_frontend", False): # use getattr to keep compatibility
self.frontend = frontend_for(args, idim)
self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
idim = args.n_mels
else:
self.frontend = None
# encoder
self.enc = encoder_for(args, idim, self.subsample)
# ctc
self.ctc = ctc_for(args, odim, reduce=False)
# attention
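        # with --spa (speaker parallel attention) each speaker stream gets its own
        # attention module; otherwise a single attention module is shared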
num_att = self.num_spkrs if args.spa else 1
self.att = att_for(args, num_att)
# decoder
self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)
# weight initialization
self.init_like_chainer()
# options for beam search
if "report_cer" in vars(args) and (args.report_cer or args.report_wer):
recog_args = {
"beam_size": args.beam_size,
"penalty": args.penalty,
"ctc_weight": args.ctc_weight,
"maxlenratio": args.maxlenratio,
"minlenratio": args.minlenratio,
"lm_weight": args.lm_weight,
"rnnlm": args.rnnlm,
"nbest": args.nbest,
"space": args.sym_space,
"blank": args.sym_blank,
}
self.recog_args = argparse.Namespace(**recog_args)
self.report_cer = args.report_cer
self.report_wer = args.report_wer
else:
self.report_cer = False
self.report_wer = False
self.rnnlm = None
self.logzero = -10000000000.0
self.loss = None
self.acc = None
    def init_like_chainer(self):
"""Initialize weight like chainer.
chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
however, there are two exceptions as far as I know.
- EmbedID.W ~ Normal(0, 1)
- LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
"""
lecun_normal_init_parameters(self)
# exceptions
# embed weight ~ Normal(0, 1)
self.dec.embed.weight.data.normal_(0, 1)
# forget-bias = 1.0
# https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
for i in range(len(self.dec.decoder)):
set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
    def forward(self, xs_pad, ilens, ys_pad):
"""E2E forward.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad:
batch of padded character id sequence tensor (B, num_spkrs, Lmax)
:return: ctc loss value
:rtype: torch.Tensor
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy in attention decoder
:rtype: float
"""
import editdistance
# 0. Frontend
if self.frontend is not None:
hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
if isinstance(hs_pad, list):
hlens_n = [None] * self.num_spkrs
for i in range(self.num_spkrs):
hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
hlens = hlens_n
else:
hs_pad, hlens = self.feature_transform(hs_pad, hlens)
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
if not isinstance(
hs_pad, list
): # single-channel input xs_pad (single- or multi-speaker)
hs_pad, hlens, _ = self.enc(hs_pad, hlens)
else: # multi-channel multi-speaker input xs_pad
for i in range(self.num_spkrs):
hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])
# 2. CTC loss
if self.mtlalpha == 0:
loss_ctc, min_perm = None, None
else:
if not isinstance(hs_pad, list): # single-speaker input xs_pad
loss_ctc = torch.mean(self.ctc(hs_pad, hlens, ys_pad))
else: # multi-speaker input xs_pad
ys_pad = ys_pad.transpose(0, 1) # (num_spkrs, B, Lmax)
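                # pairwise CTC losses in hypothesis-major order (h1r1, h1r2, ..., hNrN):
                # i // num_spkrs selects the hypothesis (encoder output stream),
                # i % num_spkrs selects the reference speaker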
loss_ctc_perm = torch.stack(
[
self.ctc(
hs_pad[i // self.num_spkrs],
hlens[i // self.num_spkrs],
ys_pad[i % self.num_spkrs],
)
for i in range(self.num_spkrs**2)
],
dim=1,
) # (B, num_spkrs^2)
loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
logging.info("ctc loss:" + str(float(loss_ctc)))
# 3. attention loss
if self.mtlalpha == 1:
loss_att = None
acc = None
else:
if not isinstance(hs_pad, list): # single-speaker input xs_pad
loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
else:
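                # reorder the reference transcripts according to the best permutation
                # found by CTC-based PIT so that each decoder stream is trained
                # against its matching speaker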
for i in range(ys_pad.size(1)): # B
ys_pad[:, i] = ys_pad[min_perm[i], i]
rslt = [
self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i)
for i in range(self.num_spkrs)
]
loss_att = sum([r[0] for r in rslt]) / float(len(rslt))
acc = sum([r[1] for r in rslt]) / float(len(rslt))
self.acc = acc
# 4. compute cer without beam search
if self.mtlalpha == 0 or self.char_list is None:
cer_ctc = None
else:
cers = []
for ns in range(self.num_spkrs):
y_hats = self.ctc.argmax(hs_pad[ns]).data
for i, y in enumerate(y_hats):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[ns][i]
seq_hat = [
self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.blank, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
hyp_chars = seq_hat_text.replace(" ", "")
ref_chars = seq_true_text.replace(" ", "")
if len(ref_chars) > 0:
cers.append(
editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
)
cer_ctc = sum(cers) / len(cers) if cers else None
# 5. compute cer/wer
if (
self.training
or not (self.report_cer or self.report_wer)
or not isinstance(hs_pad, list)
):
cer, wer = 0.0, 0.0
else:
if self.recog_args.ctc_weight > 0.0:
lpz = [
self.ctc.log_softmax(hs_pad[i]).data for i in range(self.num_spkrs)
]
else:
lpz = None
word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], []
nbest_hyps = [
self.dec.recognize_beam_batch(
hs_pad[i],
torch.tensor(hlens[i]),
                    lpz[i] if lpz is not None else None,
self.recog_args,
self.char_list,
self.rnnlm,
strm_idx=i,
)
for i in range(self.num_spkrs)
]
# remove <sos> and <eos>
y_hats = [
[nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps[i]]
for i in range(self.num_spkrs)
]
for i in range(len(y_hats[0])):
hyp_words = []
hyp_chars = []
ref_words = []
ref_chars = []
for ns in range(self.num_spkrs):
y_hat = y_hats[ns][i]
y_true = ys_pad[ns][i]
seq_hat = [
self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
seq_true_text = "".join(seq_true).replace(
self.recog_args.space, " "
)
hyp_words.append(seq_hat_text.split())
ref_words.append(seq_true_text.split())
hyp_chars.append(seq_hat_text.replace(" ", ""))
ref_chars.append(seq_true_text.replace(" ", ""))
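                # pairwise edit distances over all hypothesis/reference pairs,
                # scored with the same PIT assignment rule as the training losses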
tmp_word_ed = [
editdistance.eval(
hyp_words[ns // self.num_spkrs], ref_words[ns % self.num_spkrs]
)
for ns in range(self.num_spkrs**2)
] # h1r1,h1r2,h2r1,h2r2
tmp_char_ed = [
editdistance.eval(
hyp_chars[ns // self.num_spkrs], ref_chars[ns % self.num_spkrs]
)
for ns in range(self.num_spkrs**2)
] # h1r1,h1r2,h2r1,h2r2
word_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0])
word_ref_lens.append(len(sum(ref_words, [])))
char_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0])
char_ref_lens.append(len("".join(ref_chars)))
wer = (
0.0
if not self.report_wer
else float(sum(word_eds)) / sum(word_ref_lens)
)
cer = (
0.0
if not self.report_cer
else float(sum(char_eds)) / sum(char_ref_lens)
)
alpha = self.mtlalpha
if alpha == 0:
self.loss = loss_att
loss_att_data = float(loss_att)
loss_ctc_data = None
elif alpha == 1:
self.loss = loss_ctc
loss_att_data = None
loss_ctc_data = float(loss_ctc)
else:
self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
loss_att_data = float(loss_att)
loss_ctc_data = float(loss_ctc)
loss_data = float(self.loss)
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
self.reporter.report(
loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
)
else:
logging.warning("loss (=%f) is not correct", loss_data)
return self.loss
    def recognize(self, x, recog_args, char_list, rnnlm=None):
"""E2E beam search.
:param ndarray x: input acoustic feature (T, D)
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
prev = self.training
self.eval()
ilens = [x.shape[0]]
# subsample frame
x = x[:: self.subsample[0], :]
h = to_device(self, to_torch_tensor(x).float())
        # wrap the single utterance in a batch of size 1 to reuse the encoder interface
hs = h.contiguous().unsqueeze(0)
# 0. Frontend
if self.frontend is not None:
hs, hlens, mask = self.frontend(hs, ilens)
hlens_n = [None] * self.num_spkrs
for i in range(self.num_spkrs):
hs[i], hlens_n[i] = self.feature_transform(hs[i], hlens)
hlens = hlens_n
else:
hs, hlens = hs, ilens
# 1. Encoder
if not isinstance(hs, list): # single-channel multi-speaker input x
hs, hlens, _ = self.enc(hs, hlens)
else: # multi-channel multi-speaker input x
for i in range(self.num_spkrs):
hs[i], hlens[i], _ = self.enc(hs[i], hlens[i])
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
lpz = [self.ctc.log_softmax(i)[0] for i in hs]
else:
lpz = None
# 2. decoder
# decode the first utterance
        y = [
            self.dec.recognize_beam(
                hs[i][0],
                lpz[i] if lpz is not None else None,
                recog_args,
                char_list,
                rnnlm,
                strm_idx=i,
            )
            for i in range(self.num_spkrs)
        ]
if prev:
self.train()
return y
    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E batch beam search.
        :param list xs: list of input acoustic features, each of shape (T, D)
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
prev = self.training
self.eval()
ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
# subsample frame
xs = [xx[:: self.subsample[0], :] for xx in xs]
xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
xs_pad = pad_list(xs, 0.0)
# 0. Frontend
if self.frontend is not None:
hs_pad, hlens, mask = self.frontend(xs_pad, ilens)
hlens_n = [None] * self.num_spkrs
for i in range(self.num_spkrs):
hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
hlens = hlens_n
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
if not isinstance(hs_pad, list): # single-channel multi-speaker input x
hs_pad, hlens, _ = self.enc(hs_pad, hlens)
else: # multi-channel multi-speaker input x
for i in range(self.num_spkrs):
hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
lpz = [self.ctc.log_softmax(hs_pad[i]) for i in range(self.num_spkrs)]
normalize_score = False
else:
lpz = None
normalize_score = True
# 2. decoder
y = [
self.dec.recognize_beam_batch(
hs_pad[i],
hlens[i],
                lpz[i] if lpz is not None else None,
recog_args,
char_list,
rnnlm,
normalize_score=normalize_score,
strm_idx=i,
)
for i in range(self.num_spkrs)
]
if prev:
self.train()
return y
    def enhance(self, xs):
        """Forward only the frontend stage.
        :param list xs: list of input acoustic features, each of shape (T, C, F)
"""
if self.frontend is None:
raise RuntimeError("Frontend doesn't exist")
prev = self.training
self.eval()
ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
# subsample frame
xs = [xx[:: self.subsample[0], :] for xx in xs]
xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
xs_pad = pad_list(xs, 0.0)
enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
if prev:
self.train()
if isinstance(enhanced, (tuple, list)):
enhanced = list(enhanced)
mask = list(mask)
for idx in range(len(enhanced)): # number of speakers
enhanced[idx] = enhanced[idx].cpu().numpy()
mask[idx] = mask[idx].cpu().numpy()
return enhanced, mask, ilens
return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
"""E2E attention calculation.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad:
batch of padded character id sequence tensor (B, num_spkrs, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
with torch.no_grad():
# 0. Frontend
if self.frontend is not None:
hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
hlens_n = [None] * self.num_spkrs
for i in range(self.num_spkrs):
hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
hlens = hlens_n
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
if not isinstance(hs_pad, list): # single-channel multi-speaker input x
hs_pad, hlens, _ = self.enc(hs_pad, hlens)
else: # multi-channel multi-speaker input x
for i in range(self.num_spkrs):
hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])
# Permutation
ys_pad = ys_pad.transpose(0, 1) # (num_spkrs, B, Lmax)
if self.num_spkrs <= 3:
loss_ctc = torch.stack(
[
self.ctc(
hs_pad[i // self.num_spkrs],
hlens[i // self.num_spkrs],
ys_pad[i % self.num_spkrs],
)
for i in range(self.num_spkrs**2)
],
1,
) # (B, num_spkrs^2)
loss_ctc, min_perm = self.pit.pit_process(loss_ctc)
for i in range(ys_pad.size(1)): # B
ys_pad[:, i] = ys_pad[min_perm[i], i]
# 2. Decoder
att_ws = [
self.dec.calculate_all_attentions(
hs_pad[i], hlens[i], ys_pad[i], strm_idx=i
)
for i in range(self.num_spkrs)
]
return att_ws
class EncoderMix(torch.nn.Module):
"""Encoder module for the case of multi-speaker mixture speech.
:param str etype: type of encoder network
:param int idim: number of dimensions of encoder network
    :param int elayers_sd:
        number of layers in the speaker-differentiating part of the encoder network
    :param int elayers_rec:
        number of layers in the shared recognition part of the encoder network
:param int eunits: number of lstm units of encoder network
:param int eprojs: number of projection units of encoder network
:param np.ndarray subsample: list of subsampling numbers
:param float dropout: dropout rate
:param int in_channel: number of input channels
    :param int num_spkrs: number of speakers
"""
def __init__(
self,
etype,
idim,
elayers_sd,
elayers_rec,
eunits,
eprojs,
subsample,
dropout,
num_spkrs=2,
in_channel=1,
):
"""Initialize the encoder of single-channel multi-speaker ASR."""
super(EncoderMix, self).__init__()
typ = etype.lstrip("vgg").rstrip("p")
if typ not in ["lstm", "gru", "blstm", "bgru"]:
logging.error("Error: need to specify an appropriate encoder architecture")
if etype.startswith("vgg"):
if etype[-1] == "p":
self.enc_mix = torch.nn.ModuleList([VGG2L(in_channel)])
self.enc_sd = torch.nn.ModuleList(
[
torch.nn.ModuleList(
[
RNNP(
get_vgg2l_odim(idim, in_channel=in_channel),
elayers_sd,
eunits,
eprojs,
subsample[: elayers_sd + 1],
dropout,
typ=typ,
)
]
)
for i in range(num_spkrs)
]
)
self.enc_rec = torch.nn.ModuleList(
[
RNNP(
eprojs,
elayers_rec,
eunits,
eprojs,
subsample[elayers_sd:],
dropout,
typ=typ,
)
]
)
logging.info("Use CNN-VGG + B" + typ.upper() + "P for encoder")
else:
logging.error(
f"Error: need to specify an appropriate encoder architecture. "
f"Illegal name {etype}"
)
sys.exit()
else:
logging.error(
f"Error: need to specify an appropriate encoder architecture. "
f"Illegal name {etype}"
)
sys.exit()
self.num_spkrs = num_spkrs
    def forward(self, xs_pad, ilens):
        """Encodermix forward.
        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: list of batches of hidden state sequences [num_spkrs x (B, Tmax, eprojs)]
        :rtype: list of torch.Tensor
        """
# mixture encoder
for module in self.enc_mix:
xs_pad, ilens, _ = module(xs_pad, ilens)
# SD and Rec encoder
xs_pad_sd = [xs_pad for i in range(self.num_spkrs)]
ilens_sd = [ilens for i in range(self.num_spkrs)]
for ns in range(self.num_spkrs):
# Encoder_SD: speaker differentiate encoder
for module in self.enc_sd[ns]:
xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])
# Encoder_Rec: recognition encoder
for module in self.enc_rec:
xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])
# make mask to remove bias value in padded part
mask = to_device(xs_pad, make_pad_mask(ilens_sd[0]).unsqueeze(-1))
return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd, None
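# A minimal construction sketch for EncoderMix (parameter values are illustrative):
#
#     enc = EncoderMix("vggblstmp", idim=83, elayers_sd=4, elayers_rec=4,
#                      eunits=320, eprojs=320,
#                      subsample=np.ones(9, dtype=np.int64), dropout=0.0,
#                      num_spkrs=2)
#     hs_list, hlens_list, _ = enc(xs_pad, ilens)  # num_spkrs hidden streams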
def encoder_for(args, idim, subsample):
    """Construct the encoder."""
    if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
        # with a frontend, the mixed speech is already separated into one stream
        # per speaker, so the standard single-speaker encoder is used
        return encoder_for_single(args, idim, subsample)
else:
return EncoderMix(
args.etype,
idim,
args.elayers_sd,
args.elayers,
args.eunits,
args.eprojs,
subsample,
args.dropout_rate,
args.num_spkrs,
)