# Source code for espnet2.enh.layers.adapt_layers

# noqa E501: Ported from https://github.com/BUTSpeechFIT/speakerbeam/blob/main/src/models/adapt_layers.py
# Copyright (c) 2021 Brno University of Technology
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
# All rights reserved
# By Katerina Zmolikova, August 2021.

from functools import partial

import torch
import torch.nn as nn


def make_adapt_layer(type, indim, enrolldim, ninputs=1):
    """Build an adaptation layer from its registered name.

    Args:
        type: key into ``adaptation_layer_types`` (e.g. "concat", "muladd",
            "mul").
        indim: channel dimension of the main-network activations.
        enrolldim: dimension of the enrollment embedding.
        ninputs: number of activation streams adapted at once.

    Returns:
        An instance of the adaptation layer registered under ``type``.

    Raises:
        ValueError: if ``type`` is not a registered adaptation layer name.
    """
    adapt_class = adaptation_layer_types.get(type)
    if adapt_class is None:
        # Without this check the call below fails with the unhelpful
        # "'NoneType' object is not callable".
        raise ValueError(
            f"Unknown adaptation layer type {type!r}; "
            f"expected one of {sorted(adaptation_layer_types)}"
        )
    return adapt_class(indim, enrolldim, ninputs)
def into_tuple(x):
    """Transforms tensor/list/tuple into tuple.

    Args:
        x: a ``torch.Tensor``, list or tuple.

    Returns:
        ``tuple(x)`` for a list, a one-element tuple ``(x,)`` for a tensor,
        and ``x`` itself for a tuple.

    Raises:
        ValueError: if ``x`` is none of the supported types.
    """
    if isinstance(x, list):
        return tuple(x)
    if isinstance(x, torch.Tensor):
        return (x,)
    if isinstance(x, tuple):
        return x
    raise ValueError("x should be tensor, list or tuple")
def into_orig_type(x, orig_type):
    """Inverts into_tuple function.

    Args:
        x: tuple in the layout produced by ``into_tuple``.
        orig_type: the type of the value originally passed to ``into_tuple``
            (``tuple``, ``list``, or ``torch.Tensor`` / a subclass of it).

    Returns:
        ``x`` itself for ``tuple``, ``list(x)`` for ``list`` and the single
        element ``x[0]`` for tensor types.

    Raises:
        ValueError: if ``orig_type`` is not one of the supported types.
    """
    if orig_type is tuple:
        return x
    if orig_type is list:
        return list(x)
    # ``into_tuple`` accepts any Tensor instance (e.g. nn.Parameter), so
    # accept Tensor subclasses here too.
    if isinstance(orig_type, type) and issubclass(orig_type, torch.Tensor):
        return x[0]
    # Explicit raise instead of ``assert False``: asserts vanish under -O,
    # which would make this silently return None.
    raise ValueError(f"Unsupported orig_type: {orig_type!r}")
class ConcatAdaptLayer(nn.Module):
    """Adapt activations by concatenating an enrollment embedding.

    The embedding is repeated along the last axis of the activations,
    concatenated to them on dim 1, and projected back to ``indim`` channels
    with one ``nn.Linear`` per adapted input.
    """

    def __init__(self, indim, enrolldim, ninputs=1):
        super().__init__()
        self.ninputs = ninputs
        # One projection (indim + enrolldim -> indim) per adapted input.
        self.transform = nn.ModuleList(
            [nn.Linear(indim + enrolldim, indim) for _ in range(ninputs)]
        )

    def forward(self, main, enroll):
        """ConcatAdaptLayer forward.

        Args:
            main: tensor or tuple or list
                activations in the main neural network, which are adapted
                tuple/list may be useful when we want to apply the adaptation
                to both normal and skip connection at once
            enroll: tensor or tuple or list
                embedding extracted from enrollment
                tuple/list may be useful when we want to apply the adaptation
                to both normal and skip connection at once

        Returns:
            The adapted activations, in the same container type as ``main``.
        """
        assert type(main) == type(enroll)
        orig_type = type(main)
        main, enroll = into_tuple(main), into_tuple(enroll)
        assert len(main) == len(enroll) == self.ninputs
        out = []
        for transform, main0, enroll0 in zip(self.transform, main, enroll):
            # Expand only the batch and last axes. Expanding to the full
            # ``main0.shape`` forced enrolldim == indim, contradicting the
            # Linear layer's indim + enrolldim input size.
            enroll0 = enroll0[:, :, None].expand(
                main0.shape[0], -1, main0.shape[-1]
            )
            combined = torch.cat((main0, enroll0), dim=1)
            # nn.Linear acts on the last axis: move channels there and back.
            out.append(transform(combined.permute(0, 2, 1)).permute(0, 2, 1))
        return into_orig_type(tuple(out), orig_type)
class MulAddAdaptLayer(nn.Module):
    """Adapt activations with an embedding-driven scale and optional shift.

    With ``do_addition=True`` the enrollment embedding is split in half on
    dim 1 into a multiplicative and an additive part, so ``enrolldim`` must
    be ``2 * indim``. Otherwise the whole embedding only scales the
    activations and ``enrolldim`` must equal ``indim``. The layer registers
    no trainable parameters.
    """

    def __init__(self, indim, enrolldim, ninputs=1, do_addition=True):
        super().__init__()
        self.ninputs = ninputs
        self.do_addition = do_addition
        expected_dim = 2 * indim if do_addition else indim
        assert enrolldim == expected_dim, (enrolldim, indim)

    def forward(self, main, enroll):
        """MulAddAdaptLayer Forward.

        Args:
            main: tensor or tuple or list
                activations in the main neural network, which are adapted
                tuple/list may be useful when we want to apply the adaptation
                to both normal and skip connection at once
            enroll: tensor or tuple or list
                embedding extracted from enrollment
                tuple/list may be useful when we want to apply the adaptation
                to both normal and skip connection at once

        Returns:
            The adapted activations, in the same container type as ``main``.
        """
        assert type(main) == type(enroll)
        orig_type = type(main)
        main_items = into_tuple(main)
        enroll_items = into_tuple(enroll)
        assert len(main_items) == len(enroll_items) == self.ninputs, (
            len(main_items),
            len(enroll_items),
            self.ninputs,
        )
        adapted = []
        for activ, emb in zip(main_items, enroll_items):
            if not self.do_addition:
                # Scale-only: broadcast the embedding over the last axis.
                adapted.append(emb[:, :, None] * activ)
                continue
            scale, shift = torch.chunk(emb, 2, dim=1)
            adapted.append(scale[:, :, None] * activ + shift[:, :, None])
        return into_orig_type(tuple(adapted), orig_type)
# Registry of adaptation layer types, keyed by the name that
# make_adapt_layer looks up; "mul" is MulAddAdaptLayer with the additive
# part disabled.
adaptation_layer_types = dict(
    concat=ConcatAdaptLayer,
    muladd=MulAddAdaptLayer,
    mul=partial(MulAddAdaptLayer, do_addition=False),
)