import argparse
from typing import Dict, Optional
import torch
from typeguard import check_argument_types
from espnet2.uasr.segmenter.abs_segmenter import AbsSegmenter
from espnet2.utils.types import str2bool
[docs]class JoinSegmenter(AbsSegmenter):
def __init__(
self,
cfg: Optional[Dict] = None,
subsample_rate: float = 0.25,
mean_pool: str2bool = True,
mean_join_pool: str2bool = False,
remove_zeros: str2bool = False,
):
super().__init__()
assert check_argument_types()
if cfg is not None:
cfg = argparse.Namespace(**cfg["segmentation"])
assert cfg.type == "JOIN"
self.subsampling_rate = cfg.subsample_rate
self.mean_pool = cfg.mean_pool
self.mean_pool_join = cfg.mean_pool_join
self.remove_zeros = cfg.remove_zeros
else:
self.mean_pool_join = mean_join_pool
self.remove_zeros = remove_zeros
[docs] def pre_segment(
self,
xs_pad: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
assert check_argument_types()
return xs_pad, padding_mask
[docs] def logit_segment(
self,
logits: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
assert check_argument_types()
preds = logits.argmax(dim=-1)
if padding_mask.any():
preds[padding_mask] = -1 # mark pad
uniques = []
batch_size, time_length, channel_size = logits.shape
for p in preds:
uniques.append(
p.cpu().unique_consecutive(return_inverse=True, return_counts=True)
)
new_time_length = max(u[0].numel() for u in uniques)
new_logits = logits.new_zeros(batch_size, new_time_length, channel_size)
new_pad = padding_mask.new_zeros(batch_size, new_time_length)
for b in range(batch_size):
value, index, count = uniques[b]
keep = value != -1
if self.remove_zeros:
keep.logical_and_(value != 0)
if self.training and not self.mean_pool_join:
value[0] = 0
value[1:] = count.cumsum(0)[:-1]
part = count > 1
random = torch.rand(part.sum())
value[part] += (count[part] * random).long()
new_logits[b, : value.numel()] = logits[b, value]
else:
new_logits[b].index_add_(
dim=0, index=index.to(new_logits.device), source=logits[b]
)
new_logits[b, : count.numel()] = new_logits[
b, : count.numel()
] / count.unsqueeze(-1).to(new_logits.device)
new_size = keep.sum()
if not keep.all():
kept_logits = new_logits[b, : count.numel()][keep]
new_logits[b, :new_size] = kept_logits
if new_size < new_time_length:
pad = new_time_length - new_size
new_logits[b, -pad:] = 0
new_pad[b, -pad:] = True
return new_logits, new_pad