# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2018-2019, Mingkun Huang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.autograd import Function
from torch.nn import Module
from espnet2.asr.transducer.rnnt_multi_blank import rnnt
from espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils import cpu_rnnt
__all__ = ["rnnt_loss", "RNNTLossNumba", "MultiblankRNNTLossNumba"]
class _RNNTNumba(Function):
@staticmethod
def forward(
ctx,
acts,
labels,
act_lens,
label_lens,
blank,
reduction,
fastemit_lambda,
clamp,
):
"""
        acts: Tensor of (batch x seqLength x labelLength x outputDim)
containing output from network
labels: 2 dimensional Tensor containing all the targets of
            the batch, zero padded
act_lens: Tensor of size (batch) containing size of each
output sequence from the network
label_lens: Tensor of (batch) containing label length of each example
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level
Emission Regularization.
"""
is_cuda = acts.is_cuda
certify_inputs(acts, labels, act_lens, label_lens)
if clamp < 0:
raise ValueError("`clamp` must be 0.0 or positive float value.")
loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu
grads = torch.zeros_like(acts) if acts.requires_grad else None
minibatch_size = acts.size(0)
costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype)
loss_func(
acts,
labels=labels,
input_lengths=act_lens,
label_lengths=label_lens,
costs=costs,
grads=grads,
blank_label=blank,
fastemit_lambda=fastemit_lambda,
clamp=clamp,
num_threads=0,
)
if reduction in ["sum", "mean"]:
costs = costs.sum().unsqueeze_(-1)
if reduction == "mean":
costs /= minibatch_size
if grads is not None:
grads /= minibatch_size
ctx.grads = grads
return costs
@staticmethod
def backward(ctx, grad_output):
if grad_output is not None and ctx.grads is not None:
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
return ctx.grads.mul_(grad_output), None, None, None, None, None, None, None
class _MultiblankRNNTNumba(Function):
"""
Numba class for multi-blank transducer loss (https://arxiv.org/pdf/2211.03541.pdf)
"""
@staticmethod
def forward(
ctx,
acts,
labels,
act_lens,
label_lens,
blank,
big_blank_durations,
reduction,
fastemit_lambda,
clamp,
sigma,
):
"""
big_blank_durations: list of durations for multi-blank transducer, e.g.
[2, 4, 8].
sigma: hyper-parameter for logit under-normalization method for training
multi-blank transducers. Recommended value 0.05.
        Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations of
            the above parameters.
        For the remaining parameters, see the docstring of _RNNTNumba.forward.
"""
is_cuda = acts.is_cuda
certify_inputs(acts, labels, act_lens, label_lens)
if clamp < 0:
raise ValueError("`clamp` must be 0.0 or positive float value.")
if is_cuda:
loss_func = rnnt.multiblank_rnnt_loss_gpu
else:
raise NotImplementedError()
grads = torch.zeros_like(acts) if acts.requires_grad else None
minibatch_size = acts.size(0)
costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype)
loss_func(
acts,
labels=labels,
input_lengths=act_lens,
label_lengths=label_lens,
costs=costs,
grads=grads,
blank_label=blank,
big_blank_durations=big_blank_durations,
fastemit_lambda=fastemit_lambda,
clamp=clamp,
sigma=sigma,
num_threads=0,
)
if reduction in ["sum", "mean"]:
costs = costs.sum().unsqueeze_(-1)
if reduction == "mean":
costs /= minibatch_size
if grads is not None:
grads /= minibatch_size
ctx.grads = grads
return costs
@staticmethod
def backward(ctx, grad_output):
if grad_output is not None and ctx.grads is not None:
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
return (
ctx.grads.mul_(grad_output),
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
)
def rnnt_loss(
acts,
labels,
act_lens,
label_lens,
blank=0,
reduction="mean",
fastemit_lambda: float = 0.0,
clamp: float = 0.0,
):
"""RNN Transducer Loss (functional form)
Args:
acts: Tensor of (batch x seqLength x labelLength x outputDim)
containing output from network
labels: 2 dimensional Tensor containing all the targets of
            the batch, zero padded
act_lens: Tensor of size (batch) containing size of each
output sequence from the network
label_lens: Tensor of (batch) containing label length of each example
blank (int, optional): blank label. Default: 0.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the output losses will be summed and divided by the batch size,
            'sum': the output losses will be summed. Default: 'mean'
        fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
            FastEmit: Low-latency Streaming ASR with Sequence-level
            Emission Regularization.
        clamp: Float value. When set to a value > 0.0, the gradient is clamped
            to [-clamp, clamp].
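    Example:
        A minimal usage sketch; the shapes and values below are illustrative
        assumptions, not taken from this module::

            import torch

            # Hypothetical sizes: batch, time, target length + 1, vocab (incl. blank)
            B, T, U, V = 2, 8, 5, 6
            acts = torch.randn(B, T, U, V, requires_grad=True)
            labels = torch.randint(1, V, (B, U - 1), dtype=torch.int32)
            act_lens = torch.full((B,), T, dtype=torch.int32)
            label_lens = torch.full((B,), U - 1, dtype=torch.int32)

            loss = rnnt_loss(acts, labels, act_lens, label_lens, blank=0)
            loss.backward()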
"""
if not acts.is_cuda:
# Since CPU requires log_softmax to be computed explicitly,
# we need to perform grad clipping
# *after* we have obtained the gradients of loss(logsoftmax()).
# This is highly wasteful since it requires a copy of the entire joint
# tensor which is expensive. CUDA version is much more efficient since
# it performs an inplace logsoftmax, and therefore
# can inplace clamp the gradient.
if clamp > 0.0:
acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, clamp)
# NOTE manually done log_softmax for CPU version,
# log_softmax is computed within GPU version.
acts = torch.nn.functional.log_softmax(acts, -1)
return _RNNTNumba.apply(
acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda, clamp
)
def multiblank_rnnt_loss(
acts,
labels,
act_lens,
label_lens,
blank,
big_blank_durations=[],
reduction="mean",
fastemit_lambda: float = 0.0,
    clamp: float = 0.0,
    sigma: float = 0.0,
):
"""
Multi-blank RNN Transducer (https://arxiv.org/pdf/2211.03541.pdf)
Loss (functional form)
Args:
acts: Tensor of (batch x seqLength x labelLength x outputDim) containing
output from network
        labels: 2 dimensional Tensor containing all the targets of the batch,
            zero padded
act_lens: Tensor of size (batch) containing size of each output
sequence from the network
label_lens: Tensor of (batch) containing label length of each example
blank (int): standard blank label.
big_blank_durations: list of durations for multi-blank transducer, e.g.
[2, 4, 8].
sigma: hyper-parameter for logit under-normalization method for training
multi-blank transducers. Recommended value 0.05.
        Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations of
            big_blank_durations and sigma.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the output losses will be summed and divided by the batch size,
            'sum': the output losses will be summed. Default: 'mean'
        fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
            FastEmit: Low-latency Streaming ASR with Sequence-level
            Emission Regularization.
        clamp: Float value. When set to a value > 0.0, the gradient is clamped
            to [-clamp, clamp].
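    Example:
        A minimal call sketch. The shapes and values are illustrative assumptions;
        in particular, the layout that places the big-blank outputs in the last
        ``len(big_blank_durations)`` slots right after the standard blank is an
        assumption to be checked against the backend's convention. A CUDA device
        is required because only the GPU kernel is implemented::

            import torch

            # Hypothetical sizes: batch, time, target length + 1, joint output dim.
            # Assumed output layout: 5 real tokens (0..4), standard blank (5),
            # two big blanks (6, 7) for durations [2, 4].
            B, T, U, V = 2, 8, 5, 8
            acts = torch.randn(B, T, U, V, device="cuda", requires_grad=True)
            labels = torch.randint(0, 5, (B, U - 1), dtype=torch.int32, device="cuda")
            act_lens = torch.full((B,), T, dtype=torch.int32, device="cuda")
            label_lens = torch.full((B,), U - 1, dtype=torch.int32, device="cuda")

            loss = multiblank_rnnt_loss(
                acts, labels, act_lens, label_lens,
                blank=5, big_blank_durations=[2, 4], sigma=0.05,
            )
            loss.backward()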
"""
if not acts.is_cuda:
# Since CPU requires log_softmax to be computed explicitly,
# we need to perform grad clipping
# *after* we have obtained the gradients of loss(logsoftmax()).
# This is highly wasteful since it requires a copy of the entire
# joint tensor which is expensive.
# CUDA version is much more efficient since it performs an inplace
# logsoftmax, and therefore can inplace clamp the gradient.
if clamp > 0.0:
acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, clamp)
# NOTE manually done log_softmax for CPU version,
# log_softmax is computed within GPU version.
acts = torch.nn.functional.log_softmax(acts, -1)
return _MultiblankRNNTNumba.apply(
acts,
labels,
act_lens,
label_lens,
blank,
big_blank_durations,
reduction,
fastemit_lambda,
        clamp,
        sigma,
)
class RNNTLossNumba(Module):
"""
Parameters:
blank (int, optional): blank label. Default: 0.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the output losses will be summed and divided by the batch size,
            'sum': the output losses will be summed. Default: 'mean'
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level
Emission Regularization.
        clamp: Float value. When set to a value > 0.0, will clamp the
            gradient to [-clamp, clamp].
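    Example:
        A minimal usage sketch mirroring the functional ``rnnt_loss`` example
        above; the shapes below are illustrative assumptions::

            import torch

            loss_fn = RNNTLossNumba(blank=0, reduction="mean")
            B, T, U, V = 2, 8, 5, 6  # hypothetical sizes
            acts = torch.randn(B, T, U, V, requires_grad=True)
            labels = torch.randint(1, V, (B, U - 1), dtype=torch.int32)
            act_lens = torch.full((B,), T, dtype=torch.int32)
            label_lens = torch.full((B,), U - 1, dtype=torch.int32)
            loss_fn(acts, labels, act_lens, label_lens).backward()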
"""
def __init__(
self, blank=0, reduction="mean", fastemit_lambda: float = 0.0, clamp: float = -1
):
super(RNNTLossNumba, self).__init__()
self.blank = blank
self.fastemit_lambda = fastemit_lambda
self.clamp = float(clamp) if clamp > 0 else 0.0
self.reduction = reduction
self.loss = _RNNTNumba.apply
    def forward(self, acts, labels, act_lens, label_lens):
"""
        acts: Tensor of (batch x seqLength x labelLength x outputDim)
            containing output from network
        labels: 2 dimensional Tensor containing all the targets of the
            batch, zero padded
act_lens: Tensor of size (batch) containing size of each output
sequence from the network
label_lens: Tensor of (batch) containing label length of each example
"""
if not acts.is_cuda:
# Since CPU requires log_softmax to be computed explicitly,
# we need to perform grad clipping
# *after* we have obtained the gradients of loss(logsoftmax()).
# This is highly wasteful since it requires a copy of the entire
# joint tensor which is expensive.
# CUDA version is much more efficient since it performs an
# inplace logsoftmax, and therefore can inplace clamp the gradient.
if self.clamp > 0.0:
acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, self.clamp)
# NOTE manually done log_softmax for CPU version,
# log_softmax is computed within GPU version.
acts = torch.nn.functional.log_softmax(acts, -1)
return self.loss(
acts,
labels,
act_lens,
label_lens,
self.blank,
self.reduction,
self.fastemit_lambda,
self.clamp,
)
class MultiblankRNNTLossNumba(Module):
"""
Parameters:
blank (int): standard blank label.
big_blank_durations: list of durations for multi-blank transducer, e.g.
[2, 4, 8].
sigma: hyper-parameter for logit under-normalization method for training
multi-blank transducers. Recommended value 0.05.
        Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations of
            the above parameters.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the output losses will be summed and divided by the batch size,
            'sum': the output losses will be summed. Default: 'mean'
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level
Emission Regularization.
        clamp: Float value. When set to a value > 0.0, will clamp the
            gradient to [-clamp, clamp].
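    Example:
        A construction sketch; the argument values are illustrative assumptions,
        and the call-time inputs follow the same (assumed) layout discussed for
        ``multiblank_rnnt_loss`` above::

            loss_fn = MultiblankRNNTLossNumba(
                blank=5, big_blank_durations=[2, 4], reduction="mean", sigma=0.05
            )
            # loss = loss_fn(acts, labels, act_lens, label_lens)  # CUDA tensors required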
"""
def __init__(
self,
blank,
big_blank_durations,
reduction="mean",
fastemit_lambda: float = 0.0,
clamp: float = -1,
sigma: float = 0.0,
):
super(MultiblankRNNTLossNumba, self).__init__()
self.blank = blank
self.big_blank_durations = big_blank_durations
self.fastemit_lambda = fastemit_lambda
self.clamp = float(clamp) if clamp > 0 else 0.0
self.reduction = reduction
self.loss = _MultiblankRNNTNumba.apply
self.sigma = sigma
    def forward(self, acts, labels, act_lens, label_lens):
"""
        acts: Tensor of (batch x seqLength x labelLength x outputDim)
            containing output from network
        labels: 2 dimensional Tensor containing all the targets of
            the batch, zero padded
act_lens: Tensor of size (batch) containing size of each output
sequence from the network
label_lens: Tensor of (batch) containing label length of each example
"""
if not acts.is_cuda:
# Since CPU requires log_softmax to be computed explicitly,
# we need to perform grad clipping
# *after* we have obtained the gradients of loss(logsoftmax()).
# This is highly wasteful since it requires a copy of the entire
# joint tensor which is expensive.
# CUDA version is much more efficient since it performs an
# inplace logsoftmax, and therefore can inplace clamp the gradient.
if self.clamp > 0.0:
acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, self.clamp)
# NOTE manually done log_softmax for CPU version,
# log_softmax is computed within GPU version.
acts = torch.nn.functional.log_softmax(acts, -1)
return self.loss(
acts,
labels,
act_lens,
label_lens,
self.blank,
self.big_blank_durations,
self.reduction,
self.fastemit_lambda,
self.clamp,
self.sigma,
)
def check_type(var, t, name):
if var.dtype is not t:
raise TypeError("{} must be {}".format(name, t))
def check_contiguous(var, name):
if not var.is_contiguous():
raise ValueError("{} must be contiguous".format(name))
def check_dim(var, dim, name):
if len(var.shape) != dim:
raise ValueError("{} must be {}D".format(name, dim))
def certify_inputs(log_probs, labels, lengths, label_lengths):
# check_type(log_probs, torch.float32, "log_probs")
check_type(labels, torch.int32, "labels")
check_type(label_lengths, torch.int32, "label_lengths")
check_type(lengths, torch.int32, "lengths")
check_contiguous(log_probs, "log_probs")
check_contiguous(labels, "labels")
check_contiguous(label_lengths, "label_lengths")
check_contiguous(lengths, "lengths")
if lengths.shape[0] != log_probs.shape[0]:
raise ValueError(
f"Must have a length per example. "
f"Given lengths dim: {lengths.shape[0]}, "
f"Log probs dim : {log_probs.shape[0]}"
)
if label_lengths.shape[0] != log_probs.shape[0]:
raise ValueError(
"Must have a label length per example. "
f"Given label lengths dim : {label_lengths.shape[0]}, "
f"Log probs dim : {log_probs.shape[0]}"
)
check_dim(log_probs, 4, "log_probs")
check_dim(labels, 2, "labels")
    check_dim(lengths, 1, "lengths")
    check_dim(label_lengths, 1, "label_lengths")
max_T = torch.max(lengths)
max_U = torch.max(label_lengths)
T, U = log_probs.shape[1:3]
if T != max_T:
raise ValueError(
f"Input length mismatch! Given T: {T}, Expected max T from input \
lengths: {max_T}"
)
if U != max_U + 1:
raise ValueError(
f"Output length mismatch! Given U: {U}, Expected max U from target \
lengths: {max_U} + 1"
)