# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2018-2019, Mingkun Huang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import torch
from numba import cuda
from espnet2.asr.transducer.rnnt_multi_blank.utils import global_constants, rnnt_helper
from espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils import cpu_rnnt
from espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils import gpu_rnnt
[docs]def rnnt_loss_cpu(
acts: torch.Tensor,
labels: torch.Tensor,
input_lengths: torch.Tensor,
label_lengths: torch.Tensor,
costs: torch.Tensor,
grads: torch.Tensor,
blank_label: int,
fastemit_lambda: float,
clamp: float,
num_threads: int,
):
"""
Wrapper method for accessing CPU RNNT loss.
CPU implementation ported from [HawkAaron/warp-transducer]
(https://github.com/HawkAaron/warp-transducer).
Args:
acts: Activation tensor of shape [B, T, U, V+1].
labels: Ground truth labels of shape [B, U].
input_lengths: Lengths of the acoustic sequence as a vector of ints [B].
label_lengths: Lengths of the target sequence as a vector of ints [B].
costs: Zero vector of length [B] in which costs will be set.
grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set.
blank_label: Index of the blank token in the vocabulary.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level
Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the
gradient to [-clamp, clamp].
num_threads: Number of threads for OpenMP.
"""
# aliases
log_probs = acts
flat_labels = labels
minibatch_size = log_probs.shape[0]
maxT = log_probs.shape[1]
maxU = log_probs.shape[2]
alphabet_size = log_probs.shape[3]
if num_threads < 0:
num_threads = multiprocessing.cpu_count()
num_threads = max(1, num_threads) # have to use at least 1 thread
gpu_size, status = rnnt_helper.get_workspace_size(
maxT, maxU, minibatch_size, gpu=False
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError(
"Invalid parameter passed when calculating working space memory"
)
cpu_workspace = torch.zeros(
gpu_size, device=log_probs.device, dtype=log_probs.dtype, requires_grad=False
)
# VIEW TENSORS AS VECTORS FOR POINTER INDEXING
log_probs, acts_shape = rnnt_helper.flatten_tensor(log_probs)
flat_labels, labels_shape = rnnt_helper.flatten_tensor(flat_labels)
wrapper = cpu_rnnt.CPURNNT(
minibatch=minibatch_size,
maxT=maxT,
maxU=maxU,
alphabet_size=alphabet_size,
workspace=cpu_workspace,
blank=blank_label,
fastemit_lambda=fastemit_lambda,
clamp=clamp,
num_threads=num_threads,
batch_first=True,
)
if grads is None:
status = wrapper.score_forward(
log_probs=log_probs.data,
costs=costs,
flat_labels=flat_labels.data,
label_lengths=label_lengths.data,
input_lengths=input_lengths.data,
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError("Could not calculate forward scores")
else:
# FLATTEN GRAD TENSOR
grads, grads_shape = rnnt_helper.flatten_tensor(grads)
status = wrapper.cost_and_grad(
log_probs=log_probs.data,
grads=grads.data,
costs=costs,
flat_labels=flat_labels.data,
label_lengths=label_lengths.data,
input_lengths=input_lengths.data,
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError("Could not calculate forward scores")
del cpu_workspace, wrapper
return True
[docs]def rnnt_loss_gpu(
acts: torch.Tensor,
labels: torch.Tensor,
input_lengths: torch.Tensor,
label_lengths: torch.Tensor,
costs: torch.Tensor,
grads: torch.Tensor,
blank_label: int,
fastemit_lambda: float,
clamp: float,
num_threads: int,
):
"""
Wrapper method for accessing GPU RNNT loss.
CUDA implementation ported from [HawkAaron/warp-transducer]
(https://github.com/HawkAaron/warp-transducer).
Args:
acts: Activation tensor of shape [B, T, U, V+1].
labels: Ground truth labels of shape [B, U].
input_lengths: Lengths of the acoustic sequence as a vector of ints [B].
label_lengths: Lengths of the target sequence as a vector of ints [B].
costs: Zero vector of length [B] in which costs will be set.
grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set.
blank_label: Index of the blank token in the vocabulary.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level
Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the
gradient to [-clamp, clamp].
num_threads: Number of threads for OpenMP.
"""
minibatch_size = acts.shape[0]
maxT = acts.shape[1]
maxU = acts.shape[2]
alphabet_size = acts.shape[3]
if hasattr(cuda, "external_stream"):
stream = cuda.external_stream(
torch.cuda.current_stream(acts.device).cuda_stream
)
else:
stream = cuda.default_stream()
if num_threads < 0:
num_threads = multiprocessing.cpu_count()
num_threads = max(1, num_threads) # have to use at least 1 thread
gpu_size, status = rnnt_helper.get_workspace_size(
maxT, maxU, minibatch_size, gpu=True
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError(
"Invalid parameter passed when calculating working space memory"
)
# Select GPU index
cuda.select_device(acts.device.index)
gpu_workspace = torch.zeros(
gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False
)
# VIEW TENSORS AS VECTORS FOR POINTER INDEXING
acts, acts_shape = rnnt_helper.flatten_tensor(acts)
wrapper = gpu_rnnt.GPURNNT(
minibatch=minibatch_size,
maxT=maxT,
maxU=maxU,
alphabet_size=alphabet_size,
workspace=gpu_workspace,
blank=blank_label,
fastemit_lambda=fastemit_lambda,
clamp=clamp,
num_threads=num_threads,
stream=stream,
)
if grads is None:
status = wrapper.score_forward(
acts=acts.data,
costs=costs.data,
pad_labels=labels.data,
label_lengths=label_lengths.data,
input_lengths=input_lengths.data,
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError("Could not calculate forward scores")
else:
# FLATTEN GRAD TENSOR
grads, grads_shape = rnnt_helper.flatten_tensor(grads)
status = wrapper.cost_and_grad(
acts=acts.data,
grads=grads.data,
costs=costs.data,
pad_labels=labels.data,
label_lengths=label_lengths.data,
input_lengths=input_lengths.data,
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError("Could not calculate forward scores")
del gpu_workspace, wrapper
return True
[docs]def multiblank_rnnt_loss_gpu(
acts: torch.Tensor,
labels: torch.Tensor,
input_lengths: torch.Tensor,
label_lengths: torch.Tensor,
costs: torch.Tensor,
grads: torch.Tensor,
blank_label: int,
big_blank_durations: list,
fastemit_lambda: float,
clamp: float,
num_threads: int,
sigma: float,
):
"""
Wrapper method for accessing GPU Multi-blank RNNT loss
(https://arxiv.org/pdf/2211.03541.pdf).
CUDA implementation ported from [HawkAaron/warp-transducer]
(https://github.com/HawkAaron/warp-transducer).
Args:
acts: Activation tensor of shape [B, T, U, V + num_big_blanks + 1].
labels: Ground truth labels of shape [B, U].
input_lengths: Lengths of the acoustic sequence as a vector of ints [B].
label_lengths: Lengths of the target sequence as a vector of ints [B].
costs: Zero vector of length [B] in which costs will be set.
grads: Zero tensor of shape [B, T, U, V + num_big_blanks + 1]
where the gradient will be set.
blank_label: Index of the standard blank token in the vocabulary.
big_blank_durations: A list of supported durations for big blank symbols
in the model, e.g. [2, 4, 8]. Note we only include durations for ``big
blanks'' here and it should not include 1 for the standard blank.
Those big blanks have vocabulary indices after the standard blank index.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level
Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the
gradient to [-clamp, clamp].
num_threads: Number of threads for OpenMP.
sigma: logit-undernormalization weight used in the multi-blank model. Refer to
the multi-blank paper https://arxiv.org/pdf/2211.03541
for detailed explanations.
"""
minibatch_size = acts.shape[0]
maxT = acts.shape[1]
maxU = acts.shape[2]
alphabet_size = acts.shape[3]
if hasattr(cuda, "external_stream"):
stream = cuda.external_stream(
torch.cuda.current_stream(acts.device).cuda_stream
)
else:
stream = cuda.default_stream()
if num_threads < 0:
num_threads = multiprocessing.cpu_count()
num_threads = max(1, num_threads) # have to use at least 1 thread
gpu_size, status = rnnt_helper.get_workspace_size(
maxT, maxU, minibatch_size, gpu=True
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError(
"Invalid parameter passed when calculating working space memory"
)
# Select GPU index
cuda.select_device(acts.device.index)
gpu_workspace = torch.zeros(
gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False
)
big_blank_workspace = torch.zeros(
len(big_blank_durations),
device=acts.device,
dtype=torch.long,
requires_grad=False,
)
for i in range(0, len(big_blank_durations)):
big_blank_workspace[i] = big_blank_durations[i]
# VIEW TENSORS AS VECTORS FOR POINTER INDEXING
acts, acts_shape = rnnt_helper.flatten_tensor(acts)
wrapper = gpu_rnnt.MultiblankGPURNNT(
minibatch=minibatch_size,
maxT=maxT,
maxU=maxU,
alphabet_size=alphabet_size,
workspace=gpu_workspace,
big_blank_workspace=big_blank_workspace,
num_big_blanks=len(big_blank_durations),
blank=blank_label,
fastemit_lambda=fastemit_lambda,
clamp=clamp,
num_threads=num_threads,
stream=stream,
sigma=sigma,
)
if grads is None:
status = wrapper.score_forward(
acts=acts.data,
costs=costs.data,
pad_labels=labels.data,
label_lengths=label_lengths.data,
input_lengths=input_lengths.data,
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError("Could not calculate forward scores")
else:
# FLATTEN GRAD TENSOR
grads, grads_shape = rnnt_helper.flatten_tensor(grads)
status = wrapper.cost_and_grad(
acts=acts.data,
grads=grads.data,
costs=costs.data,
pad_labels=labels.data,
label_lengths=label_lengths.data,
input_lengths=input_lengths.data,
)
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError("Could not calculate forward scores")
del gpu_workspace, big_blank_workspace, wrapper
return True