# Source code for espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2018-2019, Mingkun Huang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch.autograd import Function
from torch.nn import Module

from espnet2.asr.transducer.rnnt_multi_blank import rnnt
from espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils import cpu_rnnt

__all__ = ["rnnt_loss", "RNNTLossNumba", "MultiblankRNNTLossNumba"]


class _RNNTNumba(Function):
    @staticmethod
    def forward(
        ctx,
        acts,
        labels,
        act_lens,
        label_lens,
        blank,
        reduction,
        fastemit_lambda,
        clamp,
    ):
        """
        log_probs: Tensor of (batch x seqLength x labelLength x outputDim)
            containing output from network
        labels: 2 dimensional Tensor containing all the targets of
            the batch with zero padded
        act_lens: Tensor of size (batch) containing size of each
            output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
            FastEmit: Low-latency Streaming ASR with Sequence-level
            Emission Regularization.
        """

        is_cuda = acts.is_cuda

        certify_inputs(acts, labels, act_lens, label_lens)
        if clamp < 0:
            raise ValueError("`clamp` must be 0.0 or a positive float value.")

        loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu
        grads = torch.zeros_like(acts) if acts.requires_grad else None
        minibatch_size = acts.size(0)
        costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype)

        loss_func(
            acts,
            labels=labels,
            input_lengths=act_lens,
            label_lengths=label_lens,
            costs=costs,
            grads=grads,
            blank_label=blank,
            fastemit_lambda=fastemit_lambda,
            clamp=clamp,
            num_threads=0,
        )

        if reduction in ["sum", "mean"]:
            costs = costs.sum().unsqueeze_(-1)
            if reduction == "mean":
                costs /= minibatch_size

                if grads is not None:
                    grads /= minibatch_size

        ctx.grads = grads

        return costs

    @staticmethod
    def backward(ctx, grad_output):
        if grad_output is not None and ctx.grads is not None:
            grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
            return (
                ctx.grads.mul_(grad_output),
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )


class _MultiblankRNNTNumba(Function):
    """
    Numba class for multi-blank transducer loss (https://arxiv.org/pdf/2211.03541.pdf)
    """

    @staticmethod
    def forward(
        ctx,
        acts,
        labels,
        act_lens,
        label_lens,
        blank,
        big_blank_durations,
        reduction,
        fastemit_lambda,
        clamp,
        sigma,
    ):
        """
        big_blank_durations: list of durations for multi-blank transducer, e.g.
            [2, 4, 8].
        sigma: hyper-parameter for logit under-normalization method for training
            multi-blank transducers. Recommended value 0.05.
        Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations for
            the above parameters;
        For other parameters for this class, refer to comment for class _RNNTNumba
        """

        is_cuda = acts.is_cuda

        certify_inputs(acts, labels, act_lens, label_lens)
        if clamp < 0:
            raise ValueError("`clamp` must be 0.0 or a positive float value.")

        if is_cuda:
            loss_func = rnnt.multiblank_rnnt_loss_gpu
        else:
            raise NotImplementedError()

        grads = torch.zeros_like(acts) if acts.requires_grad else None
        minibatch_size = acts.size(0)
        costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype)

        loss_func(
            acts,
            labels=labels,
            input_lengths=act_lens,
            label_lengths=label_lens,
            costs=costs,
            grads=grads,
            blank_label=blank,
            big_blank_durations=big_blank_durations,
            fastemit_lambda=fastemit_lambda,
            clamp=clamp,
            sigma=sigma,
            num_threads=0,
        )

        if reduction in ["sum", "mean"]:
            costs = costs.sum().unsqueeze_(-1)
            if reduction == "mean":
                costs /= minibatch_size

                if grads is not None:
                    grads /= minibatch_size

        ctx.grads = grads

        return costs

    @staticmethod
    def backward(ctx, grad_output):
        if grad_output is not None and ctx.grads is not None:
            grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
            return (
                ctx.grads.mul_(grad_output),
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )


def rnnt_loss(
    acts,
    labels,
    act_lens,
    label_lens,
    blank=0,
    reduction="mean",
    fastemit_lambda: float = 0.0,
    clamp: float = 0.0,
):
    """RNN Transducer Loss (functional form)

    Args:
        acts: Tensor of (batch x seqLength x labelLength x outputDim)
            containing output from network
        labels: 2 dimensional Tensor containing all the targets of
            the batch, padded with zeros
        act_lens: Tensor of size (batch) containing size of each
            output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        blank (int, optional): blank label. Default: 0.
        reduction (string, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the output losses will be divided by the target
            lengths and then the mean over the batch is taken. Default: 'mean'
    """
    if not acts.is_cuda:
        # Since CPU requires log_softmax to be computed explicitly,
        # we need to perform grad clipping
        # *after* we have obtained the gradients of loss(logsoftmax()).
        # This is highly wasteful since it requires a copy of the entire joint
        # tensor which is expensive. CUDA version is much more efficient since
        # it performs an inplace logsoftmax, and therefore
        # can inplace clamp the gradient.
        if clamp > 0.0:
            acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, clamp)

        # NOTE manually done log_softmax for CPU version,
        # log_softmax is computed within GPU version.
        acts = torch.nn.functional.log_softmax(acts, -1)

    return _RNNTNumba.apply(
        acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda, clamp
    )


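# Illustrative usage sketch (not part of the original module): calling the
# functional `rnnt_loss` above on random inputs. Shapes and dtypes follow
# `certify_inputs`: acts is a (batch, T, U + 1, vocab) float tensor, while labels
# and the two length tensors are int32. The helper name and toy sizes are
# hypothetical.
def _example_rnnt_loss_usage():
    B, T, U, V = 2, 8, 4, 6  # batch, frames, target length, vocab size (incl. blank)
    acts = torch.randn(B, T, U + 1, V, requires_grad=True)
    labels = torch.randint(1, V, (B, U), dtype=torch.int32)  # 0 is reserved for blank
    act_lens = torch.full((B,), T, dtype=torch.int32)
    label_lens = torch.full((B,), U, dtype=torch.int32)
    loss = rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean")
    loss.backward()  # gradients flow back into `acts`
    return loss

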
def multiblank_rnnt_loss(
    acts,
    labels,
    act_lens,
    label_lens,
    blank,
    big_blank_durations=[],
    reduction="mean",
    fastemit_lambda: float = 0.0,
    clamp: float = 0.0,
    sigma: float = 0.0,
):
    """Multi-blank RNN Transducer Loss (functional form)
    (https://arxiv.org/pdf/2211.03541.pdf)

    Args:
        acts: Tensor of (batch x seqLength x labelLength x outputDim)
            containing output from network
        labels: 2 dimensional Tensor containing all the targets of
            the batch, padded with zeros
        act_lens: Tensor of size (batch) containing size of each
            output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        blank (int): standard blank label.
        big_blank_durations: list of durations for multi-blank transducer, e.g.
            [2, 4, 8].
        sigma: hyper-parameter for logit under-normalization method for training
            multi-blank transducers. Recommended value 0.05.
        Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations of
            the last two parameters.
        reduction (string, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the output losses will be divided by the target
            lengths and then the mean over the batch is taken. Default: 'mean'
    """
    if not acts.is_cuda:
        # Since CPU requires log_softmax to be computed explicitly,
        # we need to perform grad clipping
        # *after* we have obtained the gradients of loss(logsoftmax()).
        # This is highly wasteful since it requires a copy of the entire
        # joint tensor which is expensive.
        # CUDA version is much more efficient since it performs an inplace
        # logsoftmax, and therefore can inplace clamp the gradient.
        if clamp > 0.0:
            acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, clamp)

        # NOTE manually done log_softmax for CPU version,
        # log_softmax is computed within GPU version.
        acts = torch.nn.functional.log_softmax(acts, -1)

    return _MultiblankRNNTNumba.apply(
        acts,
        labels,
        act_lens,
        label_lens,
        blank,
        big_blank_durations,
        reduction,
        fastemit_lambda,
        clamp,
        sigma,
    )


class RNNTLossNumba(Module):
    """
    Parameters:
        blank (int, optional): blank label. Default: 0.
        reduction (string, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the output losses will be divided by the target
            lengths and then the mean over the batch is taken. Default: 'mean'
        fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
            FastEmit: Low-latency Streaming ASR with Sequence-level
            Emission Regularization.
        clamp: Float value. When set to value >= 0.0, will clamp the gradient
            to [-clamp, clamp].
    """

    def __init__(
        self, blank=0, reduction="mean", fastemit_lambda: float = 0.0, clamp: float = -1
    ):
        super(RNNTLossNumba, self).__init__()
        self.blank = blank
        self.fastemit_lambda = fastemit_lambda
        self.clamp = float(clamp) if clamp > 0 else 0.0
        self.reduction = reduction
        self.loss = _RNNTNumba.apply

    def forward(self, acts, labels, act_lens, label_lens):
        """
        acts: Tensor of (batch x seqLength x labelLength x outputDim)
            containing output from network
        labels: 2 dimensional Tensor containing all the targets of
            the batch, padded with zeros
        act_lens: Tensor of size (batch) containing size of each
            output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        """
        if not acts.is_cuda:
            # Since CPU requires log_softmax to be computed explicitly,
            # we need to perform grad clipping
            # *after* we have obtained the gradients of loss(logsoftmax()).
            # This is highly wasteful since it requires a copy of the entire
            # joint tensor which is expensive.
            # CUDA version is much more efficient since it performs an
            # inplace logsoftmax, and therefore can inplace clamp the gradient.
            if self.clamp > 0.0:
                acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, self.clamp)

            # NOTE manually done log_softmax for CPU version,
            # log_softmax is computed within GPU version.
            acts = torch.nn.functional.log_softmax(acts, -1)

        return self.loss(
            acts,
            labels,
            act_lens,
            label_lens,
            self.blank,
            self.reduction,
            self.fastemit_lambda,
            self.clamp,
        )


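# Illustrative usage sketch (not part of the original module): RNNTLossNumba used
# as a reusable criterion, as one would inside a training loop. The helper name
# and toy sizes are hypothetical; input conventions match `rnnt_loss` above.
def _example_rnnt_module_usage():
    criterion = RNNTLossNumba(
        blank=0, reduction="mean", fastemit_lambda=0.0, clamp=-1
    )
    B, T, U, V = 2, 8, 4, 6
    acts = torch.randn(B, T, U + 1, V, requires_grad=True)
    labels = torch.randint(1, V, (B, U), dtype=torch.int32)
    act_lens = torch.full((B,), T, dtype=torch.int32)
    label_lens = torch.full((B,), U, dtype=torch.int32)
    loss = criterion(acts, labels, act_lens, label_lens)
    loss.backward()
    return loss

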
class MultiblankRNNTLossNumba(Module):
    """
    Parameters:
        blank (int): standard blank label.
        big_blank_durations: list of durations for multi-blank transducer, e.g.
            [2, 4, 8].
        sigma: hyper-parameter for logit under-normalization method for training
            multi-blank transducers. Recommended value 0.05.
        Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations of
            the above parameters.
        reduction (string, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the output losses will be divided by the target
            lengths and then the mean over the batch is taken. Default: 'mean'
        fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
            FastEmit: Low-latency Streaming ASR with Sequence-level
            Emission Regularization.
        clamp: Float value. When set to value >= 0.0, will clamp the gradient
            to [-clamp, clamp].
    """

    def __init__(
        self,
        blank,
        big_blank_durations,
        reduction="mean",
        fastemit_lambda: float = 0.0,
        clamp: float = -1,
        sigma: float = 0.0,
    ):
        super(MultiblankRNNTLossNumba, self).__init__()
        self.blank = blank
        self.big_blank_durations = big_blank_durations
        self.fastemit_lambda = fastemit_lambda
        self.clamp = float(clamp) if clamp > 0 else 0.0
        self.reduction = reduction
        self.loss = _MultiblankRNNTNumba.apply
        self.sigma = sigma

    def forward(self, acts, labels, act_lens, label_lens):
        """
        acts: Tensor of (batch x seqLength x labelLength x outputDim)
            containing output from network
        labels: 2 dimensional Tensor containing all the targets of
            the batch, padded with zeros
        act_lens: Tensor of size (batch) containing size of each
            output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        """
        if not acts.is_cuda:
            # Since CPU requires log_softmax to be computed explicitly,
            # we need to perform grad clipping
            # *after* we have obtained the gradients of loss(logsoftmax()).
            # This is highly wasteful since it requires a copy of the entire
            # joint tensor which is expensive.
            # CUDA version is much more efficient since it performs an
            # inplace logsoftmax, and therefore can inplace clamp the gradient.
            if self.clamp > 0.0:
                acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, self.clamp)

            # NOTE manually done log_softmax for CPU version,
            # log_softmax is computed within GPU version.
            acts = torch.nn.functional.log_softmax(acts, -1)

        return self.loss(
            acts,
            labels,
            act_lens,
            label_lens,
            self.blank,
            self.big_blank_durations,
            self.reduction,
            self.fastemit_lambda,
            self.clamp,
            self.sigma,
        )


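# Illustrative usage sketch (not part of the original module): the multi-blank loss
# with big blanks of durations 2, 4 and 8 and the recommended sigma of 0.05. Only
# the CUDA path is implemented (_MultiblankRNNTNumba raises NotImplementedError on
# CPU), so the tensors live on a GPU. The blank index of 0 and the assumption that
# the big-blank logits occupy the last vocabulary entries are illustrative choices,
# not guarantees of the kernel's conventions.
def _example_multiblank_module_usage():
    criterion = MultiblankRNNTLossNumba(
        blank=0, big_blank_durations=[2, 4, 8], reduction="mean", sigma=0.05
    )
    B, T, U, V = 2, 16, 4, 9  # vocab: labels + standard blank + 3 big blanks
    acts = torch.randn(B, T, U + 1, V, device="cuda", requires_grad=True)
    labels = torch.randint(1, V - 3, (B, U), dtype=torch.int32, device="cuda")
    act_lens = torch.full((B,), T, dtype=torch.int32, device="cuda")
    label_lens = torch.full((B,), U, dtype=torch.int32, device="cuda")
    loss = criterion(acts, labels, act_lens, label_lens)
    loss.backward()
    return loss

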
def check_type(var, t, name):
    if var.dtype is not t:
        raise TypeError("{} must be {}".format(name, t))


def check_contiguous(var, name):
    if not var.is_contiguous():
        raise ValueError("{} must be contiguous".format(name))


def check_dim(var, dim, name):
    if len(var.shape) != dim:
        raise ValueError("{} must be {}D".format(name, dim))


def certify_inputs(log_probs, labels, lengths, label_lengths):
    # check_type(log_probs, torch.float32, "log_probs")
    check_type(labels, torch.int32, "labels")
    check_type(label_lengths, torch.int32, "label_lengths")
    check_type(lengths, torch.int32, "lengths")
    check_contiguous(log_probs, "log_probs")
    check_contiguous(labels, "labels")
    check_contiguous(label_lengths, "label_lengths")
    check_contiguous(lengths, "lengths")

    if lengths.shape[0] != log_probs.shape[0]:
        raise ValueError(
            f"Must have a length per example. "
            f"Given lengths dim: {lengths.shape[0]}, "
            f"Log probs dim : {log_probs.shape[0]}"
        )
    if label_lengths.shape[0] != log_probs.shape[0]:
        raise ValueError(
            "Must have a label length per example. "
            f"Given label lengths dim : {label_lengths.shape[0]}, "
            f"Log probs dim : {log_probs.shape[0]}"
        )

    check_dim(log_probs, 4, "log_probs")
    check_dim(labels, 2, "labels")
    check_dim(lengths, 1, "lengths")
    check_dim(label_lengths, 1, "label_lengths")

    max_T = torch.max(lengths)
    max_U = torch.max(label_lengths)
    T, U = log_probs.shape[1:3]

    if T != max_T:
        raise ValueError(
            f"Input length mismatch! Given T: {T}, "
            f"Expected max T from input lengths: {max_T}"
        )
    if U != max_U + 1:
        raise ValueError(
            f"Output length mismatch! Given U: {U}, "
            f"Expected max U from target lengths: {max_U} + 1"
        )


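# Illustrative sketch (not part of the original module) of the input contract that
# `certify_inputs` enforces: a 4D float joint tensor of shape (batch, T, U + 1, V),
# contiguous int32 labels/lengths, T equal to max(lengths), and the third dimension
# equal to max(label_lengths) + 1. The helper name and toy sizes are hypothetical.
def _example_certify_inputs():
    B, T, U, V = 2, 8, 4, 6
    log_probs = torch.randn(B, T, U + 1, V)
    labels = torch.randint(1, V, (B, U), dtype=torch.int32)
    lengths = torch.full((B,), T, dtype=torch.int32)
    label_lengths = torch.full((B,), U, dtype=torch.int32)
    certify_inputs(log_probs, labels, lengths, label_lengths)  # passes silently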