import difflib
import torch
def get_layer(l_name, library=torch.nn):
    """Return a layer object handler from ``library``, matched case-insensitively.

    E.g. if ``l_name == "elu"``, returns ``torch.nn.ELU``.

    Args:
        l_name (str): Case-insensitive name of the layer in ``library``
            (e.g. ``'elu'``).
        library (module): Module (or any object) whose attributes are
            searched for ``l_name``, e.g. ``torch.nn``. Defaults to
            ``torch.nn``.

    Returns:
        object: Handler for the requested layer (e.g. ``torch.nn.ELU``).

    Raises:
        NotImplementedError: If no attribute of ``library`` matches
            ``l_name`` case-insensitively (the message lists the closest
            attribute names), or if several attributes match because they
            differ only in case (the message lists all of them).
    """
    # Group the attribute names of the *requested* library by their lowercase
    # form: lookup is case-insensitive, but the original spelling is kept for
    # getattr and for the error messages.
    names_by_lower = {}
    for attr_name in dir(library):
        names_by_lower.setdefault(attr_name.lower(), []).append(attr_name)

    query = l_name.lower()
    matches = names_by_lower.get(query, [])

    if not matches:
        # Suggest real attribute names whose lowercase form is closest to the query.
        close_keys = difflib.get_close_matches(query, list(names_by_lower))
        close_matches = [
            name for key in close_keys for name in names_by_lower[key]
        ]
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )

    if len(matches) > 1:
        # Attributes differing only in case are ambiguous under a
        # case-insensitive lookup; report every candidate.
        raise NotImplementedError(
            "Multiple matches for layer with name {} found in {}.\n "
            "All matches: {}".format(l_name, str(library), matches)
        )

    return getattr(library, matches[0])