Showing 52 changed files with 5,503 additions and 0 deletions.
@@ -0,0 +1,44 @@
from .activations import *
from .adaptive_avgmax_pool import \
    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
    set_layer_config
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
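This package module simply re-exports the individual layer implementations under one namespace, so models can build everything through string-keyed factories. A minimal usage sketch, assuming the package is importable as `timm.models.layers` (the upstream timm layout); shapes and names below are illustrative only:

# Usage sketch -- assumes this package is importable as timm.models.layers (upstream timm layout).
import torch
from timm.models.layers import create_act_layer, create_conv2d, SelectAdaptivePool2d

conv = create_conv2d(3, 32, kernel_size=3, stride=2)        # factory resolves padding/conv variant
act = create_act_layer('hard_swish', inplace=True)          # string name -> activation module
pool = SelectAdaptivePool2d(pool_type='avg', flatten=True)  # pooling also selected by string

x = torch.randn(2, 3, 224, 224)
out = pool(act(conv(x)))                                    # shape: (2, 32)

Keeping construction behind factories such as `create_act_layer` and `create_conv2d` is what lets model definitions swap activations, norms, and attention blocks from configuration alone.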
@@ -0,0 +1,145 @@
""" Activations | ||
A collection of activations fn and modules with a common interface so that they can | ||
easily be swapped. All have an `inplace` arg even if not used. | ||
Hacked together by / Copyright 2020 Ross Wightman | ||
""" | ||
|
||
import torch | ||
from torch import nn as nn | ||
from torch.nn import functional as F | ||
|
||
|
||
def swish(x, inplace: bool = False): | ||
"""Swish - Described in: https://arxiv.org/abs/1710.05941 | ||
""" | ||
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) | ||
|
||
|
||
class Swish(nn.Module): | ||
def __init__(self, inplace: bool = False): | ||
super(Swish, self).__init__() | ||
self.inplace = inplace | ||
|
||
def forward(self, x): | ||
return swish(x, self.inplace) | ||
|
||
|
||
def mish(x, inplace: bool = False): | ||
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 | ||
NOTE: I don't have a working inplace variant | ||
""" | ||
return x.mul(F.softplus(x).tanh()) | ||
|
||
|
||
class Mish(nn.Module): | ||
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 | ||
""" | ||
def __init__(self, inplace: bool = False): | ||
super(Mish, self).__init__() | ||
|
||
def forward(self, x): | ||
return mish(x) | ||
|
||
|
||
def sigmoid(x, inplace: bool = False): | ||
return x.sigmoid_() if inplace else x.sigmoid() | ||
|
||
|
||
# PyTorch has this, but not with a consistent inplace argument interface
class Sigmoid(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Sigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x.sigmoid_() if self.inplace else x.sigmoid()


def tanh(x, inplace: bool = False):
    return x.tanh_() if inplace else x.tanh()


# PyTorch has this, but not with a consistent inplace argument interface
class Tanh(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Tanh, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x.tanh_() if self.inplace else x.tanh()


def hard_swish(x, inplace: bool = False):
    inner = F.relu6(x + 3.).div_(6.)
    return x.mul_(inner) if inplace else x.mul(inner)


class HardSwish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, self.inplace)


def hard_sigmoid(x, inplace: bool = False):
    if inplace:
        return x.add_(3.).clamp_(0., 6.).div_(6.)
    else:
        return F.relu6(x + 3.) / 6.


class HardSigmoid(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_sigmoid(x, self.inplace)


def hard_mish(x, inplace: bool = False):
    """ Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
    https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    if inplace:
        return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))
    else:
        return 0.5 * x * (x + 2).clamp(min=0, max=2)


class HardMish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_mish(x, self.inplace)


class PReLU(nn.PReLU):
    """Applies PReLU (w/ dummy inplace arg)
    """
    def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
        super(PReLU, self).__init__(num_parameters=num_parameters, init=init)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.prelu(input, self.weight)


def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    return F.gelu(x)


class GELU(nn.Module):
    """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
    """
    def __init__(self, inplace: bool = False):
        super(GELU, self).__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.gelu(input)
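Every module above accepts the same `inplace` keyword, even where it is ignored, so an activation class can be passed in as a constructor argument and swapped freely. A small sketch of that pattern; the `ConvBlock` helper is hypothetical (not part of this commit) and the import path assumes the upstream timm layout:

# Sketch of the common `inplace` interface; ConvBlock below is a hypothetical example module.
import torch
from torch import nn
from timm.models.layers.activations import Swish, HardMish, hard_swish  # assumed module path

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, act_layer=Swish):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = act_layer(inplace=True)   # any activation class above accepts `inplace`

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

# Swap the activation without touching the rest of the block.
block = ConvBlock(3, 16, act_layer=HardMish)
y = block(torch.randn(1, 3, 32, 32))

# The functional forms agree whether or not they run in place.
x = torch.randn(8)
assert torch.allclose(hard_swish(x.clone(), inplace=True), hard_swish(x, inplace=False))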
@@ -0,0 +1,90 @@
""" Activations | ||
A collection of jit-scripted activations fn and modules with a common interface so that they can | ||
easily be swapped. All have an `inplace` arg even if not used. | ||
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not | ||
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted | ||
versions if they contain in-place ops. | ||
Hacked together by / Copyright 2020 Ross Wightman | ||
""" | ||
|
||
import torch | ||
from torch import nn as nn | ||
from torch.nn import functional as F | ||
|
||
|
||
@torch.jit.script | ||
def swish_jit(x, inplace: bool = False): | ||
"""Swish - Described in: https://arxiv.org/abs/1710.05941 | ||
""" | ||
return x.mul(x.sigmoid()) | ||
|
||
|
||
@torch.jit.script | ||
def mish_jit(x, _inplace: bool = False): | ||
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 | ||
""" | ||
return x.mul(F.softplus(x).tanh()) | ||
|
||
|
||
class SwishJit(nn.Module): | ||
def __init__(self, inplace: bool = False): | ||
super(SwishJit, self).__init__() | ||
|
||
def forward(self, x): | ||
return swish_jit(x) | ||
|
||
|
||
class MishJit(nn.Module): | ||
def __init__(self, inplace: bool = False): | ||
super(MishJit, self).__init__() | ||
|
||
def forward(self, x): | ||
return mish_jit(x) | ||
|
||
|
||
@torch.jit.script | ||
def hard_sigmoid_jit(x, inplace: bool = False): | ||
# return F.relu6(x + 3.) / 6. | ||
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? | ||
|
||
|
||
class HardSigmoidJit(nn.Module): | ||
def __init__(self, inplace: bool = False): | ||
super(HardSigmoidJit, self).__init__() | ||
|
||
def forward(self, x): | ||
return hard_sigmoid_jit(x) | ||
|
||
|
||
@torch.jit.script | ||
def hard_swish_jit(x, inplace: bool = False): | ||
# return x * (F.relu6(x + 3.) / 6) | ||
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? | ||
|
||
|
||
class HardSwishJit(nn.Module): | ||
def __init__(self, inplace: bool = False): | ||
super(HardSwishJit, self).__init__() | ||
|
||
def forward(self, x): | ||
return hard_swish_jit(x) | ||
|
||
|
||
@torch.jit.script | ||
def hard_mish_jit(x, inplace: bool = False): | ||
""" Hard Mish | ||
Experimental, based on notes by Mish author Diganta Misra at | ||
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md | ||
""" | ||
return 0.5 * x * (x + 2).clamp(min=0, max=2) | ||
|
||
|
||
class HardMishJit(nn.Module): | ||
def __init__(self, inplace: bool = False): | ||
super(HardMishJit, self).__init__() | ||
|
||
def forward(self, x): | ||
return hard_mish_jit(x) |
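As the module docstring notes, the scripted variants deliberately avoid in-place ops so TorchScript kernel fusion can apply across them. A short sketch showing that they are numerically interchangeable with the eager versions; the import paths are assumptions based on the upstream timm layout, not defined in this commit:

# Sketch: scripted and eager activations compute the same values.
# Module paths below are assumed (upstream timm layout).
import torch
from timm.models.layers.activations import swish, hard_mish
from timm.models.layers.activations_jit import swish_jit, hard_mish_jit, SwishJit

x = torch.randn(4, 8)
assert torch.allclose(swish_jit(x), swish(x))           # x * sigmoid(x), scripted vs eager
assert torch.allclose(hard_mish_jit(x), hard_mish(x))   # 0.5 * x * clamp(x + 2, 0, 2)

m = SwishJit(inplace=True)   # `inplace` is accepted for interface parity but ignored
y = m(x)                     # x itself is left unmodified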