# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2018-2019, Mingkun Huang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import multiprocessing
from typing import Optional
import numba
import torch
from torch.autograd import Function
from espnet2.asr.transducer.rnnt_multi_blank.utils import global_constants


def log_sum_exp(a: torch.Tensor, b: torch.Tensor):
"""
Logsumexp with safety checks for infs.
"""
if torch.isinf(a):
return b
if torch.isinf(b):
return a
if a > b:
return math.log1p(math.exp(b - a)) + a
else:
return math.log1p(math.exp(a - b)) + b
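

# Illustrative note, not part of the original implementation: log_sum_exp
# computes log(exp(a) + exp(b)) in a numerically stable way by factoring out
# the larger of the two arguments. For example, with a = log(0.25) and
# b = log(0.75) the result is log(0.25 + 0.75) = 0.0; if either input is
# -inf, the other input is returned unchanged.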


class CpuRNNT_index:
def __init__(
self, U: int, maxU: int, minibatch: int, alphabet_size: int, batch_first: bool
):
"""
A placeholder Index computation class that emits the resolved index in a
        flattened tensor, mimicking pointer indexing in CUDA kernels on the CPU.
Args:
U: Length of the current target sample (without padding).
            maxU: Maximum length of the padded target samples.
            minibatch: Size of the minibatch.
alphabet_size: Size of the vocabulary including RNNT blank - V+1.
batch_first: Bool flag determining if batch index is first or third.
"""
super(CpuRNNT_index, self).__init__()
self.U = U
self.maxU = maxU
self.minibatch = minibatch
self.alphabet_size = alphabet_size
self.batch_first = batch_first
def __call__(self, t: int, u: int, v: Optional[int] = None):
# if indexing all the values of the vocabulary, then only t, u are provided
if v is None:
return t * self.U + u
else:
# otherwise, t, u, v are provided to index particular value
# in the vocabulary.
if self.batch_first:
return (t * self.maxU + u) * self.alphabet_size + v
else:
return (t * self.maxU + u) * self.minibatch * self.alphabet_size + v
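

# Illustrative note, not part of the original implementation: with U=3, maxU=4,
# minibatch=2, alphabet_size=5 and batch_first=True,
#   idx(1, 2)    -> 1 * 3 + 2           = 5   (cell offset in a [T, U] grid)
#   idx(1, 2, 4) -> (1 * 4 + 2) * 5 + 4 = 34  (offset in a [T, maxU, V+1] block)
# which mirrors pointer arithmetic in the corresponding CUDA kernels.


class CpuRNNT_metadata:
    # Sketch of the workspace bookkeeping assumed by CPURNNT below; the original
    # definition is not shown in this excerpt, so treat the slicing details as an
    # assumption. The constructor signature and the attribute names (alphas,
    # betas, log_probs2) follow how CPURNNT uses this class.
    def __init__(
        self,
        T: int,
        U: int,
        workspace: torch.Tensor,
        bytes_used: int,
        blank: int,
        labels: torch.Tensor,
        log_probs: torch.Tensor,
        idx: CpuRNNT_index,
    ):
        super(CpuRNNT_metadata, self).__init__()
        # Slice per-utterance working memory out of the flat workspace tensor.
        self.alphas = workspace[bytes_used : bytes_used + T * U]
        bytes_used += T * U
        self.betas = workspace[bytes_used : bytes_used + T * U]
        bytes_used += T * U
        # Cache of two log-probabilities per (t, u) cell: blank and target label.
        self.log_probs2 = workspace[bytes_used : bytes_used + T * U * 2]
        bytes_used += T * U * 2
        self.bytes_used = bytes_used
        self.setup_probs(T, U, labels, blank, log_probs, idx)

    def setup_probs(
        self,
        T: int,
        U: int,
        labels: torch.Tensor,
        blank: int,
        log_probs: torch.Tensor,
        idx: CpuRNNT_index,
    ):
        # Populate the cache so that log_probs2[idx(t, u) * 2] holds the blank
        # log-probability and log_probs2[idx(t, u) * 2 + 1] the label
        # log-probability, matching the indexing used by compute_alphas() and
        # compute_betas_and_grads() below.
        for t in range(T):
            for u in range(U):
                offset = (t * U + u) * 2
                self.log_probs2[offset] = log_probs[idx(t, u, blank)]
                # The last row (u == U - 1) has no label transition.
                if u < U - 1:
                    self.log_probs2[offset + 1] = log_probs[idx(t, u, labels[u])]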


class LogSoftmaxGradModification(Function):
    @staticmethod
def forward(ctx, acts, clamp):
if clamp < 0:
raise ValueError("`clamp` must be 0.0 or positive float.")
# This is needed for correctness (inplace is problematic),
        # but it wastes a lot of memory.
res = acts.new(acts)
ctx.clamp = clamp
return res

    @staticmethod
def backward(ctx, grad_output):
# Clamp the gradients of loss(logsoftmax(...))
        # CPU computes logsoftmax explicitly, so we only need to clamp the incoming gradients here.
grad_output = torch.clamp(grad_output, -ctx.clamp, ctx.clamp)
return (
grad_output,
None,
)
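

# Illustrative note, not part of the original implementation: the clamp value
# only matters for the backward pass. A hypothetical call such as
#   acts = LogSoftmaxGradModification.apply(acts, 0.5)
# returns a copy of `acts` unchanged in the forward pass, while any gradient
# flowing back through it is clipped to the range [-0.5, 0.5].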


class CPURNNT:
def __init__(
self,
minibatch: int,
maxT: int,
maxU: int,
alphabet_size: int,
workspace: torch.Tensor,
blank: int,
fastemit_lambda: float,
clamp: float,
num_threads: int,
batch_first: bool,
):
"""
Helper class to compute the Transducer Loss on CPU.
Args:
minibatch: Size of the minibatch b.
maxT: The maximum possible acoustic sequence length.
Represents T in the logprobs tensor.
maxU: The maximum possible target sequence length.
Represents U in the logprobs tensor.
alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
workspace: An allocated chunk of memory that will be sliced off and
reshaped into required blocks used as working memory.
blank: Index of the RNNT blank token in the vocabulary.
Generally the first or last token in the vocab.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level
Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the
gradient to [-clamp, clamp].
num_threads: Number of OMP threads to launch.
batch_first: Bool that decides if batch dimension is first or third.
"""
self.minibatch_ = minibatch
self.maxT_ = maxT
self.maxU_ = maxU
self.alphabet_size_ = alphabet_size
# a flat vector of floatX numbers that represents allocated memory slices
self.workspace = workspace
self.blank_ = blank
self.fastemit_lambda_ = fastemit_lambda
self.clamp_ = abs(clamp)
self.num_threads_ = num_threads
self.batch_first = batch_first
if num_threads > 0:
numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
else:
self.num_threads_ = numba.get_num_threads()
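
    # Illustrative note, not part of the original implementation: judging from
    # cost_and_grad() and score_forward() below, `workspace` is expected to
    # provide roughly 4 * maxT * maxU elements per utterance (alphas, betas,
    # and a cached blank/label log-probability pair per (t, u) cell).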

    def cost_and_grad_kernel(
self,
log_probs: torch.Tensor,
grad: torch.Tensor,
labels: torch.Tensor,
mb: int,
T: int,
U: int,
bytes_used: int,
):
idx = CpuRNNT_index(
U, self.maxU_, self.minibatch_, self.alphabet_size_, self.batch_first
)
rnntm = CpuRNNT_metadata(
T, U, self.workspace, bytes_used, self.blank_, labels, log_probs, idx
)
if self.batch_first:
# zero grads
grad *= 0.0
llForward = self.compute_alphas(rnntm.log_probs2, T, U, rnntm.alphas)
llBackward = self.compute_betas_and_grads(
grad, rnntm.log_probs2, T, U, rnntm.alphas, rnntm.betas, labels, llForward
)
# Scale llForward by FastEmit lambda
llForward *= 1.0 + self.fastemit_lambda_
llBackward *= 1.0 + self.fastemit_lambda_
diff = (llForward - llBackward).abs()
if diff > 0.1:
print(f"WARNING: Forward backward likelihood mismatch : {diff}")
return -llForward
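
    # Illustrative note, not part of the original implementation: the returned
    # loss is simply -llForward; llBackward is produced as a by-product of the
    # gradient computation and serves only as a numerical consistency check.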

    def compute_alphas(
self, log_probs: torch.Tensor, T: int, U: int, alphas: torch.Tensor
):
"""
Compute the probability of the forward variable alpha.
Args:
            log_probs: Cached blank/label log-probabilities, flattened from shape [T, U, 2].
T: Length of the acoustic sequence T (not padded).
U: Length of the target sequence U (not padded).
alphas: Working space memory for alpha of shape [B, T, U].
Returns:
Loglikelihood of the forward variable alpha.
"""
idx = CpuRNNT_index(
U, self.maxU_, self.minibatch_, self.alphabet_size_, self.batch_first
)
alphas[0] = 0
for t in range(T):
for u in range(U):
if u == 0 and t > 0:
alphas[idx(t, 0)] = (
alphas[idx(t - 1, 0)] + log_probs[idx(t - 1, 0) * 2]
)
if t == 0 and u > 0:
alphas[idx(0, u)] = (
alphas[idx(0, u - 1)] + log_probs[idx(0, u - 1) * 2 + 1]
)
if t > 0 and u > 0:
no_emit = alphas[idx(t - 1, u)] + log_probs[idx(t - 1, u) * 2]
emit = alphas[idx(t, u - 1)] + log_probs[idx(t, u - 1) * 2 + 1]
alphas[idx(t, u)] = log_sum_exp(emit, no_emit)
loglike = alphas[idx(T - 1, U - 1)] + log_probs[idx(T - 1, U - 1) * 2]
return loglike
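
    # Illustrative note, not part of the original implementation: the loop above
    # implements the standard RNN-T forward recursion
    #   alpha(t, u) = logsumexp(alpha(t-1, u) + blank(t-1, u),
    #                           alpha(t, u-1) + label(t, u-1))
    # with alpha(0, 0) = 0, and the total log-likelihood is
    #   alpha(T-1, U-1) + blank(T-1, U-1).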

    def compute_betas_and_grads(
self,
grad: torch.Tensor,
log_probs: torch.Tensor,
T: int,
U: int,
alphas: torch.Tensor,
betas: torch.Tensor,
labels: torch.Tensor,
logll: torch.Tensor,
):
"""
Compute backward variable beta as well as gradients of the activation
matrix wrt loglikelihood of forward variable.
Args:
grad: Working space memory of flattened shape [B, T, U, V+1]
            log_probs: Cached blank/label log-probabilities, flattened from shape [T, U, 2].
T: Length of the acoustic sequence T (not padded).
U: Length of the target sequence U (not padded).
alphas: Working space memory for alpha of shape [B, T, U].
            betas: Working space memory for beta of shape [B, T, U].
labels: Ground truth label of shape [B, U]
logll: Loglikelihood of the forward variable.
Returns:
Loglikelihood of the forward variable and inplace updates the grad tensor.
"""
idx = CpuRNNT_index(
U, self.maxU_, self.minibatch_, self.alphabet_size_, self.batch_first
)
betas[idx(T - 1, U - 1)] = log_probs[idx(T - 1, U - 1) * 2]
for t in range(T - 1, -1, -1):
for u in range(U - 1, -1, -1):
if (u == U - 1) and (t < T - 1):
betas[idx(t, U - 1)] = (
betas[idx(t + 1, U - 1)] + log_probs[idx(t, U - 1) * 2]
)
if (t == T - 1) and (u < U - 1):
betas[idx(T - 1, u)] = (
betas[idx(T - 1, u + 1)] + log_probs[idx(T - 1, u) * 2 + 1]
)
if (t < T - 1) and (u < U - 1):
no_emit = betas[idx(t + 1, u)] + log_probs[idx(t, u) * 2]
emit = betas[idx(t, u + 1)] + log_probs[idx(t, u) * 2 + 1]
betas[idx(t, u)] = log_sum_exp(emit, no_emit)
loglike = betas[0]
# // Gradients w.r.t. log probabilities
for t in range(T):
for u in range(U):
if t < T - 1:
g = alphas[idx(t, u)] + betas[idx(t + 1, u)]
grad[idx(t, u, self.blank_)] = -torch.exp(
log_probs[idx(t, u) * 2] + g - loglike
)
if u < U - 1:
g = alphas[idx(t, u)] + betas[idx(t, u + 1)]
grad[idx(t, u, labels[u])] = -torch.exp(
math.log1p(self.fastemit_lambda_)
+ log_probs[idx(t, u) * 2 + 1]
+ g
- loglike
)
# // gradient to the last blank transition
grad[idx(T - 1, U - 1, self.blank_)] = -torch.exp(
log_probs[idx(T - 1, U - 1) * 2] + alphas[idx(T - 1, U - 1)] - loglike
)
return loglike
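
    # Illustrative note, not part of the original implementation: each stored
    # gradient is the derivative of the negative log-likelihood w.r.t. the
    # corresponding log-probability, e.g. for the blank transition
    #   d(-logP) / d(log p_blank(t, u)) = -exp(alpha(t, u) + beta(t+1, u)
    #                                          + log p_blank(t, u) - logP)
    # and analogously for the label transition, whose term is additionally
    # scaled by (1 + fastemit_lambda) through the log1p factor above.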

    def cost_and_grad(
self,
log_probs: torch.Tensor,
grads: torch.Tensor,
costs: torch.Tensor,
flat_labels: torch.Tensor,
label_lengths: torch.Tensor,
input_lengths: torch.Tensor,
) -> global_constants.RNNTStatus:
# // per minibatch memory
per_minibatch_bytes = 0
# // alphas & betas
per_minibatch_bytes += self.maxT_ * self.maxU_ * 2
# // blank & label log probability cache
per_minibatch_bytes += self.maxT_ * self.maxU_ * 2
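        # Illustrative note, not part of the original implementation: despite the
        # name, `per_minibatch_bytes` appears to count workspace elements per
        # utterance (maxT * maxU each for alphas and betas, plus 2 * maxT * maxU
        # for the cached blank/label log-probabilities) and is used as an element
        # offset into `self.workspace`.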
for mb in range(self.minibatch_):
T = input_lengths[mb] # // Length of utterance (time)
U = label_lengths[mb] + 1 # // Number of labels in transcription
batch_size = self.alphabet_size_
if self.batch_first:
batch_size = self.maxT_ * self.maxU_ * self.alphabet_size_
costs[mb] = self.cost_and_grad_kernel(
log_probs[(mb * batch_size) :],
grads[(mb * batch_size) :],
flat_labels[(mb * (self.maxU_ - 1)) :],
mb,
T,
U,
mb * per_minibatch_bytes,
)
return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS

    def score_forward(
self,
log_probs: torch.Tensor,
costs: torch.Tensor,
flat_labels: torch.Tensor,
label_lengths: torch.Tensor,
input_lengths: torch.Tensor,
):
# // per minibatch memory
per_minibatch_bytes = 0
# // alphas & betas
per_minibatch_bytes += self.maxT_ * self.maxU_ * 2
# // blank & label log probability cache
per_minibatch_bytes += self.maxT_ * self.maxU_ * 2
for mb in range(self.minibatch_):
T = input_lengths[mb] # // Length of utterance (time)
U = label_lengths[mb] + 1 # // Number of labels in transcription
batch_size = self.alphabet_size_
if self.batch_first:
batch_size = self.maxT_ * self.maxU_ * self.alphabet_size_
idx = CpuRNNT_index(
U, self.maxU_, self.minibatch_, self.alphabet_size_, self.batch_first
)
rnntm = CpuRNNT_metadata(
T,
U,
self.workspace,
mb * per_minibatch_bytes,
self.blank_,
flat_labels[(mb * (self.maxU_ - 1)) :],
log_probs[(mb * batch_size) :],
idx,
)
costs[mb] = -self.compute_alphas(rnntm.log_probs2, T, U, rnntm.alphas)
return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS
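

# Minimal usage sketch, not part of the original implementation; the values and
# shapes below are illustrative assumptions. log_sum_exp and CpuRNNT_index can
# be exercised standalone, while CPURNNT additionally needs a pre-allocated flat
# workspace tensor and activations that have already been log-softmaxed.
if __name__ == "__main__":
    a = torch.tensor(math.log(0.25))
    b = torch.tensor(math.log(0.75))
    # Numerically stable log(exp(a) + exp(b)) = log(0.25 + 0.75) = 0.0
    print(float(log_sum_exp(a, b)))

    # Flattened indexing: 1 * 3 + 2 = 5 and (1 * 4 + 2) * 5 + 4 = 34
    idx = CpuRNNT_index(U=3, maxU=4, minibatch=2, alphabet_size=5, batch_first=True)
    print(idx(1, 2), idx(1, 2, 4))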