Spaces:
Runtime error
Runtime error
| """ Attention Factory | |
| Hacked together by / Copyright 2021 Ross Wightman | |
| """ | |
| import torch | |
| from functools import partial | |
| from .bottleneck_attn import BottleneckAttn | |
| from .cbam import CbamModule, LightCbamModule | |
| from .eca import EcaModule, CecaModule | |
| from .gather_excite import GatherExcite | |
| from .global_context import GlobalContext | |
| from .halo_attn import HaloAttn | |
| from .involution import Involution | |
| from .lambda_layer import LambdaLayer | |
| from .non_local_attn import NonLocalAttn, BatNonLocalAttn | |
| from .selective_kernel import SelectiveKernel | |
| from .split_attn import SplitAttn | |
| from .squeeze_excite import SEModule, EffectiveSEModule | |
| from .swin_attn import WindowAttention | |
def get_attn(attn_type):
    """Resolve an attention specifier to an attention layer class (or factory).

    Args:
        attn_type: One of
            * a ``torch.nn.Module`` instance -- returned unchanged;
            * ``None`` -- no attention; returns ``None``;
            * a ``str`` naming a known attention type, matched
              case-insensitively (e.g. ``'se'``, ``'eca'``, ``'cbam'``,
              ``'halo'``, ``'nl'``);
            * a ``bool`` -- ``True`` selects ``SEModule``, ``False`` returns
              ``None``;
            * anything else (presumably a layer class or other callable) --
              returned unchanged.

    Returns:
        The resolved class, ``functools.partial`` factory, module instance,
        or ``None``.

    Raises:
        AssertionError: if ``attn_type`` is a string that names no known
            attention type.
    """
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    if attn_type is None:
        return None

    if isinstance(attn_type, str):
        key = attn_type.lower()
        registry = {
            # Lightweight attention modules (channel and/or coarse spatial).
            # Typically added to existing network architecture blocks in
            # addition to existing convolutions.
            'se': SEModule,
            'ese': EffectiveSEModule,
            'eca': EcaModule,
            'ecam': partial(EcaModule, use_mlp=True),
            'ceca': CecaModule,
            'ge': GatherExcite,
            'gc': GlobalContext,
            'cbam': CbamModule,
            'lcbam': LightCbamModule,
            # Attention / attention-like modules w/ significant params.
            # Typically replace some of the existing workhorse convs in a
            # network architecture. All of these accept a stride argument
            # and can spatially downsample the input.
            'sk': SelectiveKernel,
            'splat': SplitAttn,
            # Self-attention / attention-like modules w/ significant compute
            # and/or params. Typically replace some of the existing workhorse
            # convs in a network architecture. All of these accept a stride
            # argument and can spatially downsample the input.
            'lambda': LambdaLayer,
            'bottleneck': BottleneckAttn,
            'halo': HaloAttn,
            'swin': WindowAttention,
            'involution': Involution,
            'nl': NonLocalAttn,
            'bat': BatNonLocalAttn,
        }
        if key not in registry:
            # Explicit raise rather than `assert False`: assert statements are
            # stripped under `python -O`, which would silently turn a typo into
            # "no attention". AssertionError is kept for existing callers.
            raise AssertionError("Invalid attn module (%s)" % key)
        return registry[key]

    if isinstance(attn_type, bool):
        # True is shorthand for the default squeeze-excite module.
        return SEModule if attn_type else None

    # Assumed to already be a layer class / factory; passed through untouched.
    return attn_type
def create_attn(attn_type, channels, **kwargs):
    """Build an attention layer for ``channels`` input channels.

    Args:
        attn_type: Any specifier accepted by ``get_attn`` (string name, bool,
            ``None``, layer class / factory, or an already-built
            ``torch.nn.Module``).
        channels: Number of input channels, passed as the first positional
            argument to the resolved layer constructor.
        **kwargs: Extra keyword arguments forwarded to the layer constructor.

    Returns:
        A newly constructed attention layer, the given ``torch.nn.Module``
        instance unchanged when one is passed, or ``None`` when ``attn_type``
        resolves to no attention.
    """
    module_cls = get_attn(attn_type)
    if module_cls is None:
        return None
    if isinstance(module_cls, torch.nn.Module):
        # Already a built layer: calling it would run its forward() on
        # `channels` instead of constructing anything, so hand it back as-is.
        return module_cls
    # NOTE: it's expected the first (positional) argument of all attention
    # layers is the input channels.
    return module_cls(channels, **kwargs)