"""Search algorithms for Transducer models."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from espnet2.asr_transducer.decoder.abs_decoder import AbsDecoder
from espnet2.asr_transducer.joint_network import JointNetwork

@dataclass
class Hypothesis:
"""Default hypothesis definition for Transducer search algorithms.
Args:
score: Total log-probability.
yseq: Label sequence as integer ID sequence.
dec_state: RNN/MEGA Decoder state (None if Stateless).
lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None
"""
score: float
yseq: List[int]
dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
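
# Note: every search below is seeded with a single hypothesis whose label
# sequence holds only the blank ID 0, e.g. (given a decoder instance)
#   Hypothesis(score=0.0, yseq=[0], dec_state=decoder.init_state(1))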

@dataclass
class ExtendedHypothesis(Hypothesis):
"""Extended hypothesis definition for NSC beam search and mAES.
Args:
: Hypothesis dataclass arguments.
dec_out: Decoder output sequence. (B, D_dec)
lm_score: Log-probabilities of the LM for given label. (vocab_size)
"""
dec_out: Optional[torch.Tensor] = None
lm_score: Optional[torch.Tensor] = None

class BeamSearchTransducer:
"""Beam search implementation for Transducer.
Args:
decoder: Decoder module.
joint_network: Joint network module.
beam_size: Size of the beam.
lm: LM module.
lm_weight: LM weight for soft fusion.
search_type: Search algorithm to use during inference.
max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
u_max: Maximum expected target sequence length. (ALSD)
nstep: Number of maximum expansion steps at each time step. (mAES)
expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
expansion_beta:
Number of additional candidates for expanded hypotheses selection. (mAES)
score_norm: Normalize final scores by length.
nbest: Number of final hypothesis.
streaming: Whether to perform chunk-by-chunk beam search.
"""
def __init__(
self,
decoder: AbsDecoder,
joint_network: JointNetwork,
beam_size: int,
lm: Optional[torch.nn.Module] = None,
lm_weight: float = 0.1,
search_type: str = "default",
max_sym_exp: int = 3,
u_max: int = 50,
nstep: int = 2,
expansion_gamma: float = 2.3,
expansion_beta: int = 2,
score_norm: bool = False,
nbest: int = 1,
streaming: bool = False,
) -> None:
"""Construct a BeamSearchTransducer object."""
super().__init__()
self.decoder = decoder
self.joint_network = joint_network
self.vocab_size = decoder.vocab_size
assert beam_size <= self.vocab_size, (
"beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
% (
beam_size,
self.vocab_size,
)
)
self.beam_size = beam_size
if search_type == "default":
self.search_algorithm = self.default_beam_search
elif search_type == "tsd":
assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
max_sym_exp
)
self.max_sym_exp = max_sym_exp
self.search_algorithm = self.time_sync_decoding
elif search_type == "alsd":
assert not streaming, "ALSD is not available in streaming mode."
assert u_max >= 0, "u_max should be a non-negative integer, a portion of max_T."
self.u_max = u_max
self.search_algorithm = self.align_length_sync_decoding
elif search_type == "maes":
assert self.vocab_size >= beam_size + expansion_beta, (
"beam_size (%d) + expansion_beta (%d) should be smaller than "
"or equal to vocabulary size (%d)."
% (beam_size, expansion_beta, self.vocab_size)
)
self.max_candidates = beam_size + expansion_beta
self.nstep = nstep
self.expansion_gamma = expansion_gamma
self.search_algorithm = self.modified_adaptive_expansion_search
else:
raise NotImplementedError(
"Specified search type (%s) is not supported." % search_type
)
self.use_lm = lm is not None
if self.use_lm:
assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."
self.sos = self.vocab_size - 1
self.lm = lm
self.lm_weight = lm_weight
self.score_norm = score_norm
self.nbest = nbest
self.reset_cache()
def __call__(
self,
enc_out: torch.Tensor,
is_final: bool = True,
) -> List[Hypothesis]:
"""Perform beam search.
Args:
enc_out: Encoder output sequence. (T, D_enc)
is_final: Whether enc_out is the final chunk of data.
Returns:
nbest_hyps: N-best decoding results
"""
self.decoder.set_device(enc_out.device)
hyps = self.search_algorithm(enc_out)
if is_final:
self.reset_cache()
return self.sort_nbest(hyps)
self.search_cache = hyps
return hyps
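
# Streaming usage sketch (hypothetical `enc_chunks` list of encoder output
# chunks; not applicable to search_type="alsd"):
#
#   for chunk in enc_chunks[:-1]:
#       _ = searcher(chunk, is_final=False)  # hypotheses cached internally
#   nbest_hyps = searcher(enc_chunks[-1], is_final=True)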

def reset_cache(self) -> None:
"""Reset cache for streaming decoding."""
self.decoder.score_cache = {}
self.search_cache = None

def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Sort in-place hypotheses by score or score given sequence length.
Args:
hyps: Hypothesis.
Return:
hyps: Sorted hypothesis.
"""
if self.score_norm:
hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
else:
hyps.sort(key=lambda x: x.score, reverse=True)
return hyps[: self.nbest]
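
# Worked example (score_norm=True): with score -4.0 and len(yseq) == 4 the
# sort key is -1.0; with score -3.5 and len(yseq) == 2 it is -1.75, so the
# longer hypothesis ranks first despite its lower raw score.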

def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Recombine hypotheses with same label ID sequence.
Args:
hyps: Hypotheses.
Returns:
final: Recombined hypotheses.
"""
final = {}
for hyp in hyps:
str_yseq = "_".join(map(str, hyp.yseq))
if str_yseq in final:
final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
else:
final[str_yseq] = hyp
return [*final.values()]
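
# Worked example: two hypotheses over the same label sequence with scores
# log(0.5) and log(0.25) recombine into one hypothesis with score
# np.logaddexp(log(0.5), log(0.25)) == log(0.75), keeping the probability
# mass of both alignments.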

def select_k_expansions(
self,
hyps: List[ExtendedHypothesis],
topk_idx: torch.Tensor,
topk_logp: torch.Tensor,
) -> List[ExtendedHypothesis]:
"""Return K hypotheses candidates for expansion from a list of hypothesis.
K candidates are selected according to the extended hypotheses probabilities
and a prune-by-value method. Where K is equal to beam_size + beta.
Args:
hyps: Hypotheses.
topk_idx: Indices of candidates hypothesis.
topk_logp: Log-probabilities of candidates hypothesis.
Returns:
k_expansions: Best K expansion hypotheses candidates.
"""
k_expansions = []
for i, hyp in enumerate(hyps):
hyp_i = [
(int(k), hyp.score + float(v))
for k, v in zip(topk_idx[i], topk_logp[i])
]
k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
k_expansions.append(
sorted(
filter(
lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
),
key=lambda x: x[1],
reverse=True,
)
)
return k_expansions
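
# Worked example of prune-by-value: with expansion_gamma = 2.3, if the best
# expansion of a hypothesis scores -0.5, every expansion scoring below
# -0.5 - 2.3 = -2.8 is dropped before the candidates are ranked.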

def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
"""Beam search implementation without prefix search.
Modified from https://arxiv.org/pdf/1211.3711.pdf
Args:
enc_out: Encoder output sequence. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
beam_k = min(self.beam_size, (self.vocab_size - 1))
max_t = len(enc_out)
if self.search_cache is not None:
kept_hyps = self.search_cache
else:
kept_hyps = [
Hypothesis(
score=0.0,
yseq=[0],
dec_state=self.decoder.init_state(1),
)
]
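# For each encoder frame, repeatedly expand the current best hypothesis.
# A hypothesis moves to kept_hyps when it takes the blank transition; the
# frame loop advances once beam_size kept hypotheses outscore every
# candidate still awaiting expansion.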
for t in range(max_t):
hyps = kept_hyps
kept_hyps = []
while True:
max_hyp = max(hyps, key=lambda x: x.score)
hyps.remove(max_hyp)
dec_out, state = self.decoder.score(
max_hyp.yseq,
max_hyp.dec_state,
)
logp = torch.log_softmax(
self.joint_network(enc_out[t : t + 1, :], dec_out),
dim=-1,
).squeeze(0)
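# Take the top-k non-blank labels; blank (index 0) is excluded, so the
# returned indices are shifted by one (hence k + 1 below).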
top_k = logp[1:].topk(beam_k, dim=-1)
kept_hyps.append(
Hypothesis(
score=(max_hyp.score + float(logp[0:1])),
yseq=max_hyp.yseq,
dec_state=max_hyp.dec_state,
lm_state=max_hyp.lm_state,
)
)
if self.use_lm:
lm_scores, lm_state = self.lm.score(
torch.tensor(
[self.sos] + max_hyp.yseq[1:],
dtype=torch.long,
device=self.decoder.device,
),
max_hyp.lm_state,
None,
)
else:
lm_state = max_hyp.lm_state
for logp, k in zip(*top_k):
score = max_hyp.score + float(logp)
if self.use_lm:
score += self.lm_weight * lm_scores[k + 1]
hyps.append(
Hypothesis(
score=score,
yseq=max_hyp.yseq + [int(k + 1)],
dec_state=state,
lm_state=lm_state,
)
)
hyps_max = float(max(hyps, key=lambda x: x.score).score)
kept_most_prob = sorted(
[hyp for hyp in kept_hyps if hyp.score > hyps_max],
key=lambda x: x.score,
)
if len(kept_most_prob) >= self.beam_size:
kept_hyps = kept_most_prob
break
return kept_hyps

def align_length_sync_decoding(
self,
enc_out: torch.Tensor,
) -> List[Hypothesis]:
"""Alignment-length synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
h: Encoder output sequences. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
t_max = int(enc_out.size(0))
u_max = min(self.u_max, (t_max - 1))
B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
final = []
if self.use_lm:
B[0].lm_state = self.lm.zero_state()
for i in range(t_max + u_max):
A = []
B_ = []
B_enc_out = []
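# In ALSD the outer index i enumerates alignment lengths, i = t + u: a
# hypothesis with u emitted labels is scored against encoder frame
# t = i - u, and hypotheses whose frame index exceeds t_max - 1 are skipped.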
for hyp in B:
u = len(hyp.yseq) - 1
t = i - u
if t > (t_max - 1):
continue
B_.append(hyp)
B_enc_out.append((t, enc_out[t]))
if B_:
beam_enc_out = torch.stack([b[1] for b in B_enc_out])
beam_dec_out, beam_state = self.decoder.batch_score(B_)
beam_logp = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out),
dim=-1,
)
beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)
if self.use_lm:
beam_lm_scores, beam_lm_states = self.lm.batch_score(
self.create_lm_batch_inputs([b.yseq for b in B_]),
[b.lm_state for b in B_],
None,
)
for j, hyp in enumerate(B_):
new_hyp = Hypothesis(
score=(hyp.score + float(beam_logp[j, 0])),
yseq=hyp.yseq[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
)
A.append(new_hyp)
if B_enc_out[j][0] == (t_max - 1):
final.append(new_hyp)
for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
new_hyp = Hypothesis(
score=(hyp.score + float(logp)),
yseq=(hyp.yseq[:] + [int(k)]),
dec_state=self.decoder.select_state(beam_state, j),
lm_state=hyp.lm_state,
)
if self.use_lm:
new_hyp.score += self.lm_weight * beam_lm_scores[j, k]
new_hyp.lm_state = beam_lm_states[j]
A.append(new_hyp)
B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
B = self.recombine_hyps(B)
if final:
return final
return B

def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
"""Time synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
enc_out: Encoder output sequence. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
if self.search_cache is not None:
B = self.search_cache
else:
B = [
Hypothesis(
yseq=[0],
score=0.0,
dec_state=self.decoder.init_state(1),
)
]
if self.use_lm:
B[0].lm_state = self.lm.zero_state()
for enc_out_t in enc_out:
A = []
C = B
enc_out_t = enc_out_t.unsqueeze(0)
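# Allow at most max_sym_exp label emissions per encoder frame; on the last
# iteration only the blank transition is scored, so no new labels are added.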
for v in range(self.max_sym_exp):
D = []
beam_dec_out, beam_state = self.decoder.batch_score(C)
beam_logp = torch.log_softmax(
self.joint_network(enc_out_t, beam_dec_out),
dim=-1,
)
beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)
seq_A = [h.yseq for h in A]
for i, hyp in enumerate(C):
if hyp.yseq not in seq_A:
A.append(
Hypothesis(
score=(hyp.score + float(beam_logp[i, 0])),
yseq=hyp.yseq[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
)
)
else:
dict_pos = seq_A.index(hyp.yseq)
A[dict_pos].score = np.logaddexp(
A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
)
if v < (self.max_sym_exp - 1):
if self.use_lm:
beam_lm_scores, beam_lm_states = self.lm.batch_score(
self.create_lm_batch_inputs([c.yseq for c in C]),
[c.lm_state for c in C],
None,
)
for i, hyp in enumerate(C):
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
new_hyp = Hypothesis(
score=(hyp.score + float(logp)),
yseq=(hyp.yseq + [int(k)]),
dec_state=self.decoder.select_state(beam_state, i),
lm_state=hyp.lm_state,
)
if self.use_lm:
new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
new_hyp.lm_state = beam_lm_states[i]
D.append(new_hyp)
C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]
B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
return B

def modified_adaptive_expansion_search(
self,
enc_out: torch.Tensor,
) -> List[ExtendedHypothesis]:
"""Modified version of Adaptive Expansion Search (mAES).
Based on AES (https://ieeexplore.ieee.org/document/9250505) and
NSC (https://arxiv.org/abs/2201.05420).
Args:
enc_out: Encoder output sequence. (T, D_enc)
Returns:
nbest_hyps: N-best hypothesis.
"""
if self.search_cache is not None:
kept_hyps = self.search_cache
else:
init_tokens = [
ExtendedHypothesis(
yseq=[0],
score=0.0,
dec_state=self.decoder.init_state(1),
)
]
beam_dec_out, beam_state = self.decoder.batch_score(
init_tokens,
)
if self.use_lm:
beam_lm_scores, beam_lm_states = self.lm.batch_score(
self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
[h.lm_state for h in init_tokens],
None,
)
lm_state = beam_lm_states[0]
lm_score = beam_lm_scores[0]
else:
lm_state = None
lm_score = None
kept_hyps = [
ExtendedHypothesis(
yseq=[0],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
dec_out=beam_dec_out[0],
lm_state=lm_state,
lm_score=lm_score,
)
]
for enc_out_t in enc_out:
hyps = kept_hyps
kept_hyps = []
beam_enc_out = enc_out_t.unsqueeze(0)
list_b = []
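# Expand the beam up to nstep times for the current frame: blank-ending
# expansions accumulate in list_b, while non-blank expansions (list_exp)
# are rescored and may be expanded again after per-hypothesis pruning in
# select_k_expansions.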
for n in range(self.nstep):
beam_dec_out = torch.stack([h.dec_out for h in hyps])
beam_logp, beam_idx = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out),
dim=-1,
).topk(self.max_candidates, dim=-1)
k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)
list_exp = []
for i, hyp in enumerate(hyps):
for k, new_score in k_expansions[i]:
new_hyp = ExtendedHypothesis(
yseq=hyp.yseq[:],
score=new_score,
dec_out=hyp.dec_out,
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_score=hyp.lm_score,
)
if k == 0:
list_b.append(new_hyp)
else:
new_hyp.yseq.append(int(k))
if self.use_lm:
new_hyp.score += self.lm_weight * float(hyp.lm_score[k])
list_exp.append(new_hyp)
if not list_exp:
kept_hyps = sorted(
self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
)[: self.beam_size]
break
else:
beam_dec_out, beam_state = self.decoder.batch_score(
list_exp,
)
if self.use_lm:
beam_lm_scores, beam_lm_states = self.lm.batch_score(
self.create_lm_batch_inputs([h.yseq for h in list_exp]),
[h.lm_state for h in list_exp],
None,
)
if n < (self.nstep - 1):
for i, hyp in enumerate(list_exp):
hyp.dec_out = beam_dec_out[i]
hyp.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
hyp.lm_state = beam_lm_states[i]
hyp.lm_score = beam_lm_scores[i]
hyps = list_exp[:]
else:
beam_logp = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out),
dim=-1,
)
for i, hyp in enumerate(list_exp):
hyp.score += float(beam_logp[i, 0])
hyp.dec_out = beam_dec_out[i]
hyp.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
hyp.lm_state = beam_lm_states[i]
hyp.lm_score = beam_lm_scores[i]
kept_hyps = sorted(
self.recombine_hyps(list_b + list_exp),
key=lambda x: x.score,
reverse=True,
)[: self.beam_size]
return kept_hyps