Source code for espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.gpu_rnnt

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2018-2019, Mingkun Huang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
from typing import Optional, Tuple

import numba
import torch
from numba import cuda

from espnet2.asr.transducer.rnnt_multi_blank.utils import global_constants, rnnt_helper
from espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils import (
    gpu_rnnt_kernel,
    reduce,
)


class GPURNNT:
    def __init__(
        self,
        minibatch: int,
        maxT: int,
        maxU: int,
        alphabet_size: int,
        workspace,
        blank: int,
        fastemit_lambda: float,
        clamp: float,
        num_threads: int,
        stream,
    ):
        """Helper class to launch the CUDA Kernels to compute the Transducer Loss.

        Args:
            minibatch: Int representing the batch size.
            maxT: The maximum possible acoustic sequence length. Represents T in
                the logprobs tensor.
            maxU: The maximum possible target sequence length. Represents U in
                the logprobs tensor.
            alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
            workspace: An allocated chunk of memory that will be sliced off and
                reshaped into required blocks used as working memory.
            blank: Index of the RNNT blank token in the vocabulary. Generally
                the first or last token in the vocab.
            fastemit_lambda: Float scaling factor for FastEmit regularization.
                Refer to FastEmit: Low-latency Streaming ASR with Sequence-level
                Emission Regularization.
            clamp: Float value. When set to a value >= 0.0, will clamp the
                gradient to [-clamp, clamp].
            num_threads: Number of OMP threads to launch.
            stream: Numba Cuda Stream.
        """
        self.minibatch_ = minibatch
        self.maxT_ = maxT
        self.maxU_ = maxU
        self.alphabet_size_ = alphabet_size
        self.gpu_workspace = cuda.as_cuda_array(
            workspace
        )  # a flat vector of floatX numbers that represents allocated memory slices
        self.blank_ = blank
        self.fastemit_lambda_ = fastemit_lambda
        self.clamp_ = abs(clamp)
        self.num_threads_ = num_threads
        self.stream_ = stream  # type: cuda.cudadrv.driver.Stream

        if num_threads > 0:
            numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
            self.num_threads_ = numba.get_num_threads()
        else:
            self.num_threads_ = numba.get_num_threads()

    def log_softmax(self, acts: torch.Tensor, denom: torch.Tensor):
        """Computes the log softmax denominator of the input activation tensor
        and stores the result in denom.

        Args:
            acts: Activation tensor of shape [B, T, U, V+1]. The input must be
                represented as a flat tensor of shape [B * T * U * (V+1)] to
                allow pointer indexing.
            denom: A zero tensor of shape [B, T, U] (flattened to
                [B * T * U]), which will hold the log softmax denominator
                reduced over the vocabulary axis.

        Updates:
            This kernel inplace updates the `denom` tensor
        """
        # // trans_acts + pred_acts -> log_softmax denominator
        reduce.reduce_max(
            acts,
            denom,
            rows=self.alphabet_size_,
            cols=self.minibatch_ * self.maxT_ * self.maxU_,
            minus=False,
            stream=self.stream_,
        )
        reduce.reduce_exp(
            acts,
            denom,
            rows=self.alphabet_size_,
            cols=self.minibatch_ * self.maxT_ * self.maxU_,
            minus=True,
            stream=self.stream_,
        )

    def compute_cost_and_score(
        self,
        acts: torch.Tensor,
        grads: Optional[torch.Tensor],
        costs: torch.Tensor,
        labels: torch.Tensor,
        label_lengths: torch.Tensor,
        input_lengths: torch.Tensor,
    ) -> global_constants.RNNTStatus:
        """Compute both the loss and the gradients.

        Args:
            acts: A flattened tensor of shape [B, T, U, V+1] representing the
                activation matrix.
            grads: A flattened zero tensor of the same shape as acts, updated
                inplace with the gradients (or None for inference).
            costs: A zero vector of length B which will be updated inplace with
                the log probability costs.
            labels: A flattened matrix of labels of shape [B, U].
            label_lengths: A vector of length B that contains the original
                lengths of the target sequences.
            input_lengths: A vector of length B that contains the original
                lengths of the acoustic sequences.

        Updates:
            This will launch kernels that will update inplace the following
            variables:
            - grads: Gradients of the activation matrix wrt the costs vector.
            - costs: Negative log likelihood of the forward variable.

        Returns:
            An enum that either represents a successful RNNT operation or failure.
        """
        training = grads is not None

        if training:
            grads *= 0.0  # zero grads

        used_offset, (
            denom,
            alphas,
            betas,
            llForward,
            llBackward,
        ) = self._prepare_workspace()

        # START EXECUTION
        self.log_softmax(acts, denom)

        # Compute alphas
        gpu_rnnt_kernel.compute_alphas_kernel[
            self.minibatch_, self.maxU_, self.stream_, 0
        ](
            acts,
            denom,
            alphas,
            llForward,
            input_lengths,
            label_lengths,
            labels,
            self.minibatch_,
            self.maxT_,
            self.maxU_,
            self.alphabet_size_,
            self.blank_,
        )

        if training:
            # Compute betas
            gpu_rnnt_kernel.compute_betas_kernel[
                self.minibatch_, self.maxU_, self.stream_, 0
            ](
                acts,
                denom,
                betas,
                llBackward,
                input_lengths,
                label_lengths,
                labels,
                self.minibatch_,
                self.maxT_,
                self.maxU_,
                self.alphabet_size_,
                self.blank_,
            )

            # Compute gradient
            grad_blocks_per_grid = self.minibatch_ * self.maxT_ * self.maxU_
            grad_threads_per_block = gpu_rnnt_kernel.GPU_RNNT_THREAD_SIZE
            gpu_rnnt_kernel.compute_grad_kernel[
                grad_blocks_per_grid, grad_threads_per_block, self.stream_, 0
            ](
                grads,
                acts,
                denom,
                alphas,
                betas,
                llForward,
                input_lengths,
                label_lengths,
                labels,
                self.minibatch_,
                self.maxT_,
                self.maxU_,
                self.alphabet_size_,
                self.blank_,
                self.fastemit_lambda_,
                self.clamp_,
            )

        # // cost copy, negate (for log likelihood) and update with additional
        # regularizers. This needs to be done via CUDA, because we used the
        # temporary memory llForward passed to the alpha kernel, which was
        # updated with the log likelihoods. But copying this data into a
        # PyTorch pointer is more difficult (the numba API is one way).
        # Therefore we launch a pointwise CUDA kernel to update the costs
        # inplace from the data of llForward, then negate to compute the
        # log likelihood.
        threadsperblock = min(costs.shape[0], 32)
        blockspergrid = (costs.shape[0] + (threadsperblock - 1)) // threadsperblock
        rnnt_helper.compute_costs_data[blockspergrid, threadsperblock, self.stream_, 0](
            llForward, costs, self.fastemit_lambda_
        )
        self.stream_.synchronize()

        return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS

    def cost_and_grad(
        self,
        acts: torch.Tensor,
        grads: torch.Tensor,
        costs: torch.Tensor,
        pad_labels: torch.Tensor,
        label_lengths: torch.Tensor,
        input_lengths: torch.Tensor,
    ):
        if (
            acts is None
            or grads is None
            or costs is None
            or pad_labels is None
            or label_lengths is None
            or input_lengths is None
        ):
            return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE

        return self.compute_cost_and_score(
            acts, grads, costs, pad_labels, label_lengths, input_lengths
        )

    def score_forward(
        self,
        acts: torch.Tensor,
        costs: torch.Tensor,
        pad_labels: torch.Tensor,
        label_lengths: torch.Tensor,
        input_lengths: torch.Tensor,
    ):
        if (
            acts is None
            or costs is None
            or pad_labels is None
            or label_lengths is None
            or input_lengths is None
        ):
            return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE

        return self.compute_cost_and_score(
            acts, None, costs, pad_labels, label_lengths, input_lengths
        )

    def _prepare_workspace(self) -> Tuple[int, Tuple[torch.Tensor, ...]]:
        """Helper method that uses the workspace and constructs slices of it
        that can be used.

        Returns:
            An int, representing the offset of the used workspace (practically,
            the slice of the workspace consumed), and a tuple of tensors
            representing the shared workspace.
        """
        used_offset = 0

        # // denom
        denom = self.gpu_workspace[
            used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_
        ]
        used_offset += self.maxT_ * self.maxU_ * self.minibatch_

        # // alphas & betas
        alphas = self.gpu_workspace[
            used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_
        ]
        used_offset += self.maxT_ * self.maxU_ * self.minibatch_
        betas = self.gpu_workspace[
            used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_
        ]
        used_offset += self.maxT_ * self.maxU_ * self.minibatch_

        # // logllh
        llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
        used_offset += self.minibatch_
        llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
        used_offset += self.minibatch_

        return used_offset, (denom, alphas, betas, llForward, llBackward)
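
# Illustrative sketch, not part of the original module: given the slicing
# performed in GPURNNT._prepare_workspace above, the flat `workspace` tensor
# handed to GPURNNT must provide at least 3 * B * maxT * maxU + 2 * B float
# elements, since denom, alphas and betas each occupy B * maxT * maxU entries
# while llForward and llBackward occupy B entries each. The helper name below
# is hypothetical.
def _example_gpu_rnnt_workspace_numel(minibatch: int, maxT: int, maxU: int) -> int:
    # denom + alphas + betas, each of size B * T * U
    grid_terms = 3 * minibatch * maxT * maxU
    # llForward + llBackward, each of size B
    per_batch_terms = 2 * minibatch
    return grid_terms + per_batch_terms

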
class MultiblankGPURNNT(GPURNNT):
    def __init__(
        self,
        sigma: float,
        num_big_blanks: int,
        minibatch: int,
        maxT: int,
        maxU: int,
        alphabet_size: int,
        workspace,
        big_blank_workspace,
        blank: int,
        fastemit_lambda: float,
        clamp: float,
        num_threads: int,
        stream,
    ):
        """Helper class to launch the CUDA Kernels to compute the Multi-blank
        Transducer Loss (https://arxiv.org/pdf/2211.03541).

        Args:
            sigma: Hyper-parameter related to the logit-normalization method in
                training multi-blank transducers.
            num_big_blanks: Number of big blank symbols the model has. This
                should not include the standard blank symbol.
            minibatch: Int representing the batch size.
            maxT: The maximum possible acoustic sequence length. Represents T
                in the logprobs tensor.
            maxU: The maximum possible target sequence length. Represents U in
                the logprobs tensor.
            alphabet_size: The vocabulary dimension V + 1 + num_big_blanks.
            workspace: An allocated chunk of memory that will be sliced off and
                reshaped into required blocks used as working memory.
            big_blank_workspace: An allocated chunk of memory that will be
                sliced off and reshaped into required blocks used as working
                memory specifically for the multi-blank related computations.
            blank: Index of the RNNT blank token in the vocabulary. Generally
                the first or last token in the vocab.
            fastemit_lambda: Float scaling factor for FastEmit regularization.
                Refer to FastEmit: Low-latency Streaming ASR with Sequence-level
                Emission Regularization.
            clamp: Float value. When set to a value >= 0.0, will clamp the
                gradient to [-clamp, clamp].
            num_threads: Number of OMP threads to launch.
            stream: Numba Cuda Stream.
        """
        super().__init__(
            minibatch,
            maxT,
            maxU,
            alphabet_size,
            workspace,
            blank,
            fastemit_lambda,
            clamp,
            num_threads,
            stream,
        )
        self.big_blank_workspace = cuda.as_cuda_array(
            big_blank_workspace
        )  # a flat vector of integer numbers that represents allocated memory slices

        self.num_big_blanks = num_big_blanks
        self.sigma = sigma

    def compute_cost_and_score(
        self,
        acts: torch.Tensor,
        grads: Optional[torch.Tensor],
        costs: torch.Tensor,
        labels: torch.Tensor,
        label_lengths: torch.Tensor,
        input_lengths: torch.Tensor,
    ) -> global_constants.RNNTStatus:
        """Compute both the loss and the gradients.

        Args:
            acts: A flattened tensor of shape [B, T, U, V + 1 + num_big_blanks]
                representing the activation matrix.
            grads: A flattened zero tensor of the same shape as acts, updated
                inplace with the gradients (or None for inference).
            costs: A zero vector of length B which will be updated inplace with
                the log probability costs.
            labels: A flattened matrix of labels of shape [B, U].
            label_lengths: A vector of length B that contains the original
                lengths of the target sequences.
            input_lengths: A vector of length B that contains the original
                lengths of the acoustic sequences.

        Updates:
            This will launch kernels that will update inplace the following
            variables:
            - grads: Gradients of the activation matrix wrt the costs vector.
            - costs: Negative log likelihood of the forward variable.

        Returns:
            An enum that either represents a successful RNNT operation or failure.
        """
        training = grads is not None

        if training:
            grads *= 0.0  # zero grads

        _, (
            denom,
            alphas,
            betas,
            llForward,
            llBackward,
            bigblank_durations,
        ) = self._prepare_workspace()

        # START EXECUTION
        self.log_softmax(acts, denom)

        # Compute alphas
        gpu_rnnt_kernel.compute_multiblank_alphas_kernel[
            self.minibatch_, self.maxU_, self.stream_, 0
        ](
            acts,
            denom,
            self.sigma,
            alphas,
            llForward,
            input_lengths,
            label_lengths,
            labels,
            self.minibatch_,
            self.maxT_,
            self.maxU_,
            self.alphabet_size_,
            self.blank_,
            bigblank_durations,
            self.num_big_blanks,
        )

        if training:
            # Compute betas
            gpu_rnnt_kernel.compute_multiblank_betas_kernel[
                self.minibatch_, self.maxU_, self.stream_, 0
            ](
                acts,
                denom,
                self.sigma,
                betas,
                llBackward,
                input_lengths,
                label_lengths,
                labels,
                self.minibatch_,
                self.maxT_,
                self.maxU_,
                self.alphabet_size_,
                self.blank_,
                bigblank_durations,
                self.num_big_blanks,
            )

            # Compute gradient
            grad_blocks_per_grid = self.minibatch_ * self.maxT_ * self.maxU_
            grad_threads_per_block = gpu_rnnt_kernel.GPU_RNNT_THREAD_SIZE
            gpu_rnnt_kernel.compute_multiblank_grad_kernel[
                grad_blocks_per_grid, grad_threads_per_block, self.stream_, 0
            ](
                grads,
                acts,
                denom,
                self.sigma,
                alphas,
                betas,
                llForward,
                input_lengths,
                label_lengths,
                labels,
                self.minibatch_,
                self.maxT_,
                self.maxU_,
                self.alphabet_size_,
                self.blank_,
                bigblank_durations,
                self.num_big_blanks,
                self.fastemit_lambda_,
                self.clamp_,
            )

        # // cost copy, negate (for log likelihood) and update with additional
        # regularizers. This needs to be done via CUDA, because we used the
        # temporary memory llForward passed to the alpha kernel, which was
        # updated with the log likelihoods. But copying this data into a
        # PyTorch pointer is more difficult (the numba API is one way).
        # Therefore we launch a pointwise CUDA kernel to update the costs
        # inplace from the data of llForward, then negate to compute the
        # log likelihood.
        threadsperblock = min(costs.shape[0], 32)
        blockspergrid = (costs.shape[0] + (threadsperblock - 1)) // threadsperblock
        rnnt_helper.compute_costs_data[blockspergrid, threadsperblock, self.stream_, 0](
            llForward, costs, self.fastemit_lambda_
        )
        self.stream_.synchronize()

        return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS

    def cost_and_grad(
        self,
        acts: torch.Tensor,
        grads: torch.Tensor,
        costs: torch.Tensor,
        pad_labels: torch.Tensor,
        label_lengths: torch.Tensor,
        input_lengths: torch.Tensor,
    ):
        if (
            acts is None
            or grads is None
            or costs is None
            or pad_labels is None
            or label_lengths is None
            or input_lengths is None
        ):
            return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE

        return self.compute_cost_and_score(
            acts, grads, costs, pad_labels, label_lengths, input_lengths
        )

    def score_forward(
        self,
        acts: torch.Tensor,
        costs: torch.Tensor,
        pad_labels: torch.Tensor,
        label_lengths: torch.Tensor,
        input_lengths: torch.Tensor,
    ):
        if (
            acts is None
            or costs is None
            or pad_labels is None
            or label_lengths is None
            or input_lengths is None
        ):
            return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE

        return self.compute_cost_and_score(
            acts, None, costs, pad_labels, label_lengths, input_lengths
        )

    def _prepare_workspace(self) -> Tuple[int, Tuple[torch.Tensor, ...]]:
        """Helper method that uses the workspace and constructs slices of it
        that can be used.

        Returns:
            An int, representing the offset of the used workspace (practically,
            the slice of the workspace consumed), and a tuple of tensors
            representing the shared workspace.
        """
        used_offset = 0

        # // denom
        denom = self.gpu_workspace[
            used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_
        ]
        used_offset += self.maxT_ * self.maxU_ * self.minibatch_

        # // alphas & betas
        alphas = self.gpu_workspace[
            used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_
        ]
        used_offset += self.maxT_ * self.maxU_ * self.minibatch_
        betas = self.gpu_workspace[
            used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_
        ]
        used_offset += self.maxT_ * self.maxU_ * self.minibatch_

        # // logllh
        llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
        used_offset += self.minibatch_
        llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
        used_offset += self.minibatch_

        bigblank_durations = self.big_blank_workspace[: self.num_big_blanks]

        return used_offset, (
            denom,
            alphas,
            betas,
            llForward,
            llBackward,
            bigblank_durations,
        )
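
# Illustrative usage sketch, not part of the original module. It assumes a
# CUDA-capable device; the sizes, blank index, dtypes and the helper name
# below are hypothetical choices for demonstration only.
def _example_cost_and_grad():
    B, T, U, V = 2, 8, 4, 16  # assumed sizes; U is the target length plus one
    num_classes = V + 1  # vocabulary plus the standard RNNT blank
    device = torch.device("cuda")

    # Flattened joint-network activations of logical shape [B, T, U, V+1].
    acts = torch.randn(B * T * U * num_classes, device=device)
    grads = torch.zeros_like(acts)
    costs = torch.zeros(B, device=device)

    # Flattened [B, U - 1] label matrix and the per-utterance lengths.
    pad_labels = torch.randint(
        1, num_classes, (B * (U - 1),), device=device, dtype=torch.int64
    )
    label_lengths = torch.full((B,), U - 1, device=device, dtype=torch.int64)
    input_lengths = torch.full((B,), T, device=device, dtype=torch.int64)

    # Flat working memory sized as in GPURNNT._prepare_workspace:
    # 3 * B * T * U (denom, alphas, betas) + 2 * B (llForward, llBackward).
    workspace = torch.zeros(3 * B * T * U + 2 * B, device=device)

    rnnt = GPURNNT(
        minibatch=B,
        maxT=T,
        maxU=U,
        alphabet_size=num_classes,
        workspace=workspace,
        blank=0,  # assuming the blank token sits at index 0
        fastemit_lambda=0.0,
        clamp=0.0,
        num_threads=1,
        stream=cuda.default_stream(),
    )
    status = rnnt.cost_and_grad(
        acts, grads, costs, pad_labels, label_lengths, input_lengths
    )
    return status, costs, grads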