#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2020, Technische Universität München;
# Dominik Winkelbauer, Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Contribution: s920128 @ Github for prepare_tokenized_text()
"""CTC segmentation.
This file contains the core functions of CTC segmentation,
which extract utterance alignments within an audio file
from a given transcription.
For a description, see:
"CTC-Segmentation of Large Corpora for German End-to-end Speech Recognition"
https://arxiv.org/abs/2007.09127 or
https://link.springer.com/chapter/10.1007%2F978-3-030-60276-5_27
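
A typical pipeline (a sketch; here, lpz is assumed to be a (T, vocab)
matrix of CTC log posteriors from an ASR model, char_list the
corresponding token inventory, and text a list of utterance strings):

    config = CtcSegmentationParameters(char_list=char_list)
    ground_truth_mat, utt_begin_indices = prepare_text(config, text)
    timings, char_probs, state_list = ctc_segmentation(
        config, lpz, ground_truth_mat
    )
    segments = determine_utterance_segments(
        config, utt_begin_indices, char_probs, timings, text
    )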
"""
import logging
import numpy as np
logger = logging.getLogger("ctc_segmentation")
# import for table of character probabilities mapped to time
try:
from .ctc_segmentation_dyn import cython_fill_table
except ImportError:
import pyximport
pyximport.install(setup_args={"include_dirs": np.get_include()})
from .ctc_segmentation_dyn import cython_fill_table
class CtcSegmentationParameters:
"""Default values for CTC segmentation.
May need adjustment according to localization or ASR settings.
    The character set is taken from the model dictionary; it is usually
    generated with SentencePiece. An ASR model trained on the corresponding
    language and character set is needed. If the character set contains any
    punctuation
characters, "#", the Greek char "ε", or the space placeholder, adapt
these settings.
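
    Example (a minimal sketch; the char list and values here are toy
    placeholders, not recommendations):

    >>> config = CtcSegmentationParameters(char_list=["<blank>", "a", "b"])
    >>> config.set(index_duration=0.04)
    >>> config.index_duration
    0.04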
"""
max_prob = -10000000000.0
skip_prob = -10000000000.0
min_window_size = 8000
max_window_size = 100000
index_duration = 0.025
score_min_mean_over_L = 30
space = "·"
blank = 0
replace_spaces_with_blanks = False
blank_transition_cost_zero = False
preamble_transition_cost_zero = True
backtrack_from_max_t = False
self_transition = "ε"
start_of_ground_truth = "#"
excluded_characters = ".,»«•❍·"
tokenized_meta_symbol = "▁"
char_list = None
    # Legacy parameters (will be ignored in future versions)
subsampling_factor = None
frame_duration_ms = None
@property
def index_duration_in_seconds(self):
"""Derive index duration from frame duration and subsampling.
        If the legacy parameters frame_duration_ms and subsampling_factor
        are both set, they take precedence; otherwise, index_duration is
        used directly. This legacy property will be removed in later
        versions in favor of index_duration.
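
        Example:

        >>> p = CtcSegmentationParameters(
        ...     frame_duration_ms=10, subsampling_factor=4
        ... )
        >>> p.index_duration_in_seconds
        0.04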
"""
if self.subsampling_factor and self.frame_duration_ms:
t = self.frame_duration_ms * self.subsampling_factor / 1000
else:
t = self.index_duration
return t
@property
def flags(self):
"""Get configuration flags to pass to the table_fill operation."""
flags = int(self.blank_transition_cost_zero)
flags += 2 * int(self.preamble_transition_cost_zero)
return flags
def update_excluded_characters(self):
"""Remove known tokens from the list of excluded characters."""
self.excluded_characters = "".join(
[
char
for char in self.excluded_characters
                if char not in self.char_list
]
)
logger.debug(f"Excluded characters: {self.excluded_characters}")
def __init__(self, **kwargs):
"""Set all parameters as attribute at init."""
self.set(**kwargs)
def set(self, **kwargs):
"""Update CtcSegmentationParameters.
Args:
**kwargs: Key-value dict that contains all properties
with their new values. Unknown properties are ignored.
"""
for key in kwargs:
if (
not key.startswith("_")
and hasattr(self, key)
and kwargs[key] is not None
):
setattr(self, key, kwargs[key])
def __repr__(self):
"""Print all attribute as dictionary."""
output = "CtcSegmentationParameters( "
for attribute in self.__dict__.keys():
value = self.__dict__[attribute]
output += f"{attribute}={value}, "
output += ")"
return output
def ctc_segmentation(config, lpz, ground_truth):
"""Extract character-level utterance alignments.
:param config: an instance of CtcSegmentationParameters
    :param lpz: log probabilities obtained from CTC output
    :param ground_truth: ground truth label matrix, as obtained from
        one of the prepare_* functions
    :return: character timings in seconds, per-index character
        probabilities, and the state list obtained from backtracking
"""
blank = config.blank
offset = 0
audio_duration = lpz.shape[0] * config.index_duration_in_seconds
logger.info(
f"CTC segmentation of {len(ground_truth)} chars "
f"to {audio_duration:.2f}s audio "
f"({lpz.shape[0]} indices)."
)
if len(ground_truth) > lpz.shape[0] and config.skip_prob <= config.max_prob:
raise AssertionError("Audio is shorter than text!")
window_size = config.min_window_size
# Try multiple window lengths if it fails
while True:
# Create table of alignment probabilities
table = np.zeros(
[min(window_size, lpz.shape[0]), len(ground_truth)], dtype=np.float32
)
table.fill(config.max_prob)
# Use array to log window offsets per character
offsets = np.zeros([len(ground_truth)], dtype=np.int64)
# Run actual alignment of utterances
t, c = cython_fill_table(
table,
lpz.astype(np.float32),
np.array(ground_truth, dtype=np.int64),
offsets,
config.blank,
config.flags,
)
if config.backtrack_from_max_t:
t = table.shape[0] - 1
logger.debug(
f"Max. joint probability to align text to audio: "
f"{table[:, c].max()} at time index {t}"
)
# Backtracking
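        # Walk the best path backwards from cell (t, c) to (0, 0). At each
        # cell, compare the transition cost implied by the table against the
        # actual "switch" probability (advance one or more characters) and
        # the "stay" probability (blank or self-transition), and reverse
        # whichever transition matches more closely.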
timings = np.zeros([len(ground_truth)])
char_probs = np.zeros([lpz.shape[0]])
state_list = [""] * lpz.shape[0]
try:
# Do until start is reached
while t != 0 or c != 0:
# Calculate the possible transition probs towards the current cell
min_s = None
min_switch_prob_delta = np.inf
max_lpz_prob = config.max_prob
for s in range(ground_truth.shape[1]):
if ground_truth[c, s] != -1:
offset = offsets[c] - (offsets[c - 1 - s] if c - s > 0 else 0)
switch_prob = (
lpz[t + offsets[c], ground_truth[c, s]]
if c > 0
else config.max_prob
)
est_switch_prob = table[t, c] - table[t - 1 + offset, c - 1 - s]
if abs(switch_prob - est_switch_prob) < min_switch_prob_delta:
min_switch_prob_delta = abs(switch_prob - est_switch_prob)
min_s = s
max_lpz_prob = max(max_lpz_prob, switch_prob)
stay_prob = (
max(lpz[t + offsets[c], blank], max_lpz_prob)
if t > 0
else config.max_prob
)
est_stay_prob = table[t, c] - table[t - 1, c]
# Check which transition has been taken
if abs(stay_prob - est_stay_prob) > min_switch_prob_delta:
# Apply reverse switch transition
if c > 0:
                        # Log the timing and the character-to-frame alignment
for s in range(0, min_s + 1):
timings[c - s] = (
offsets[c] + t
) * config.index_duration_in_seconds
char_probs[offsets[c] + t] = max_lpz_prob
char_index = ground_truth[c, min_s]
state_list[offsets[c] + t] = config.char_list[char_index]
c -= 1 + min_s
t -= 1 - offset
else:
# Apply reverse stay transition
char_probs[offsets[c] + t] = stay_prob
state_list[offsets[c] + t] = config.self_transition
t -= 1
except IndexError:
logger.warning(
"IndexError: Backtracking was not successful, "
"the window size might be too small."
)
window_size *= 2
if window_size < config.max_window_size:
logger.warning("Increasing the window size to: " + str(window_size))
continue
else:
logger.error("Maximum window size reached.")
logger.error("Check data and character list!")
raise
break
return timings, char_probs, state_list
def prepare_text(config, text, char_list=None):
"""Prepare the given text for CTC segmentation.
    Creates a matrix of character symbols to represent the given text,
    then creates a list of character indices depending on the model's
    char list.

    :param config: an instance of CtcSegmentationParameters
    :param text: iterable of utterance transcriptions
    :param char_list: a set or list that includes all characters/symbols;
        characters not included in this list are ignored
    :return: label matrix, utterance begin indices
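
    Example (a minimal sketch with a toy char list, not one from a
    real model):

    >>> cfg = CtcSegmentationParameters(char_list=["<blank>", "a", "b"])
    >>> mat, utt_begin_indices = prepare_text(cfg, ["ab", "b"])
    >>> utt_begin_indices
    [1, 4, 6]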
"""
# temporary compatibility fix for previous espnet versions
    if isinstance(config.blank, str):
config.blank = 0
if char_list is not None:
config.char_list = char_list
blank = config.char_list[config.blank]
ground_truth = config.start_of_ground_truth
utt_begin_indices = []
for utt in text:
# One space in-between
if not ground_truth.endswith(config.space):
ground_truth += config.space
        # Start of a new utterance; remember the index
utt_begin_indices.append(len(ground_truth) - 1)
# Add chars of utterance
for char in utt:
if char.isspace() and config.replace_spaces_with_blanks:
if not ground_truth.endswith(config.space):
ground_truth += config.space
elif char in config.char_list and char not in config.excluded_characters:
ground_truth += char
# Add space to the end
if not ground_truth.endswith(config.space):
ground_truth += config.space
logger.debug(f"ground_truth: {ground_truth}")
utt_begin_indices.append(len(ground_truth) - 1)
    # Create matrix: ground truth position x number of letters the character symbol spans
max_char_len = max([len(c) for c in config.char_list])
ground_truth_mat = np.ones([len(ground_truth), max_char_len], np.int64) * -1
for i in range(len(ground_truth)):
for s in range(max_char_len):
if i - s < 0:
continue
span = ground_truth[i - s : i + 1]
span = span.replace(config.space, blank)
if span in config.char_list:
char_index = config.char_list.index(span)
ground_truth_mat[i, s] = char_index
return ground_truth_mat, utt_begin_indices
def prepare_tokenized_text(config, text):
"""Prepare the given tokenized text for CTC segmentation.
:param config: an instance of CtcSegmentationParameters
    :param text: iterable of utterance transcriptions, each a string
        of tokens separated by spaces
    :return: label matrix, utterance begin indices
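
    Example (a minimal sketch with a toy token inventory):

    >>> cfg = CtcSegmentationParameters(char_list=["<blank>", "▁a", "b"])
    >>> mat, utt_begin_indices = prepare_tokenized_text(cfg, ["▁a b"])
    >>> utt_begin_indices
    [1, 4]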
"""
ground_truth = [config.start_of_ground_truth]
utt_begin_indices = []
for utt in text:
# One space in-between
        if ground_truth[-1] != config.space:
ground_truth += [config.space]
        # Start of a new utterance; remember the index
utt_begin_indices.append(len(ground_truth) - 1)
# Add tokens of utterance
for token in utt.split():
if token in config.char_list:
                if config.replace_spaces_with_blanks and not token.startswith(
                    config.tokenized_meta_symbol
                ):
ground_truth += [config.space]
ground_truth += [token]
# Add space to the end
    if ground_truth[-1] != config.space:
ground_truth += [config.space]
logger.debug(f"ground_truth: {ground_truth}")
utt_begin_indices.append(len(ground_truth) - 1)
    # Create matrix: ground truth position x number of letters the token spans
max_char_len = 1
ground_truth_mat = np.ones([len(ground_truth), max_char_len], np.int64) * -1
for i in range(1, len(ground_truth)):
if ground_truth[i] == config.space:
ground_truth_mat[i, 0] = config.blank
else:
char_index = config.char_list.index(ground_truth[i])
ground_truth_mat[i, 0] = char_index
return ground_truth_mat, utt_begin_indices
def prepare_token_list(config, text):
"""Prepare the given token list for CTC segmentation.
    This function expects the text input in the form of a list
    of numpy arrays: [np.array([2, 5]), np.array([7, 9])].
:param config: an instance of CtcSegmentationParameters
:param text: list of numpy arrays with tokens
    :return: label matrix, utterance begin indices
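
    Example:

    >>> cfg = CtcSegmentationParameters()
    >>> mat, utt_begin_indices = prepare_token_list(
    ...     cfg, [np.array([2, 5]), np.array([7, 9])]
    ... )
    >>> utt_begin_indices
    [1, 4, 7]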
"""
ground_truth = [-1]
utt_begin_indices = []
for utt in text:
# It's not possible to detect spaces when sequence is
# already tokenized, so we skip replace_spaces_with_blanks
# Insert blanks between utterances
        if ground_truth[-1] != config.blank:
ground_truth += [config.blank]
        # Start of a new utterance; remember the index
utt_begin_indices.append(len(ground_truth) - 1)
# Append tokens to list
ground_truth += utt.tolist()
# Add a blank to the end
    if ground_truth[-1] != config.blank:
ground_truth += [config.blank]
logger.debug(f"ground_truth: {ground_truth}")
utt_begin_indices.append(len(ground_truth) - 1)
    # Create the label matrix; each token is a single label, so the span length is one
ground_truth_mat = np.array(ground_truth, dtype=np.int64).reshape(-1, 1)
return ground_truth_mat, utt_begin_indices
def determine_utterance_segments(config, utt_begin_indices, char_probs, timings, text):
"""Utterance-wise alignments from char-wise alignments.
:param config: an instance of CtcSegmentationParameters
    :param utt_begin_indices: list of ground truth positions where
        utterances begin
    :param char_probs: per-index character probabilities obtained
        from backtracking
    :param timings: character timings in seconds obtained from backtracking
:param text: list of utterances
:return: segments, a list of: utterance start and end [s], and its confidence score
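
    Example (toy values; timings and char_probs would normally come
    from ctc_segmentation):

    >>> cfg = CtcSegmentationParameters()
    >>> timings = np.array([0.0, 0.4, 0.8, 1.2])
    >>> char_probs = np.zeros(50)
    >>> segments = determine_utterance_segments(
    ...     cfg, [1, 3], char_probs, timings, ["utt"]
    ... )
    >>> len(segments)  # one (start, end, score) tuple per utterance
    1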
"""
def compute_time(index, align_type):
"""Compute start and end time of utterance.
        :param index: ground truth position of the utterance boundary
:param align_type: one of ["begin", "end"]
:return: start/end time of utterance in seconds
"""
middle = (timings[index] + timings[index - 1]) / 2
if align_type == "begin":
return max(timings[index + 1] - 0.5, middle)
elif align_type == "end":
return min(timings[index - 1] + 0.5, middle)
segments = []
min_prob = np.float64(-10000000000.0)
for i in range(len(text)):
start = compute_time(utt_begin_indices[i], "begin")
end = compute_time(utt_begin_indices[i + 1], "end")
start_t = int(round(start / config.index_duration_in_seconds))
end_t = int(round(end / config.index_duration_in_seconds))
# Compute confidence score by using the min mean probability
# after splitting into segments of L frames
n = config.score_min_mean_over_L
if end_t <= start_t:
min_avg = min_prob
elif end_t - start_t <= n:
min_avg = char_probs[start_t:end_t].mean()
else:
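            # CTC log probabilities are non-positive, so 0.0 is a safe
            # starting value for the running minimum over L-frame windows.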
min_avg = np.float64(0.0)
for t in range(start_t, end_t - n):
min_avg = min(min_avg, char_probs[t : t + n].mean())
segments.append((start, end, min_avg))
return segments