diff --git a/lib/models_timm/layers/__init__.py b/lib/models_timm/layers/__init__.py new file mode 100644 index 0000000..21c641b --- /dev/null +++ b/lib/models_timm/layers/__init__.py @@ -0,0 +1,44 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame, conv2d_same +from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import get_attn, create_attn +from .create_conv2d import create_conv2d +from .create_norm import get_norm_layer, create_norm_layer +from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ + EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm +from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from .inplace_abn import InplaceAbn +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from .selective_kernel import SelectiveKernel +from .separable_conv import SeparableConv2d, SeparableConvNormAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .trace_utils import _assert, _float_to_int +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/lib/models_timm/layers/activations.py b/lib/models_timm/layers/activations.py new file mode 100644 index 0000000..e16b3bd --- /dev/null +++ b/lib/models_timm/layers/activations.py @@ -0,0 +1,145 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + NOTE: I don't have a working inplace variant + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + + def forward(self, x): + return mish(x) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + +def hard_mish(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + if inplace: + return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) + else: + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_mish(x, self.inplace) + + +class PReLU(nn.PReLU): + """Applies PReLU (w/ dummy inplace arg) + """ + def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: + super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.prelu(input, self.weight) + + +def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x) + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) diff --git a/lib/models_timm/layers/activations_jit.py b/lib/models_timm/layers/activations_jit.py new file mode 100644 index 0000000..b4a5165 --- /dev/null +++ b/lib/models_timm/layers/activations_jit.py @@ -0,0 +1,90 @@ +""" Activations + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) + + +@torch.jit.script +def hard_mish_jit(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishJit, self).__init__() + + def forward(self, x): + return hard_mish_jit(x) diff --git a/lib/models_timm/layers/activations_me.py b/lib/models_timm/layers/activations_me.py new file mode 100644 index 0000000..9a12bb7 --- /dev/null +++ b/lib/models_timm/layers/activations_me.py @@ -0,0 +1,218 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + """ + @staticmethod + def symbolic(g, x): + return g.op("Mul", x, g.op("Sigmoid", x)) + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + @staticmethod + def symbolic(g, self): + input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float))) + hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + return g.op("Mul", self, hardtanh_) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_mish_jit_fwd(x): + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +@torch.jit.script +def hard_mish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= -2.) + m = torch.where((x >= -2.) & (x <= 0.), x + 1., m) + return grad_output * m + + +class HardMishJitAutoFn(torch.autograd.Function): + """ A memory efficient, jit scripted variant of Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_mish_jit_bwd(x, grad_output) + + +def hard_mish_me(x, inplace: bool = False): + return HardMishJitAutoFn.apply(x) + + +class HardMishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishMe, self).__init__() + + def forward(self, x): + return HardMishJitAutoFn.apply(x) + + + diff --git a/lib/models_timm/layers/adaptive_avgmax_pool.py b/lib/models_timm/layers/adaptive_avgmax_pool.py new file mode 100644 index 0000000..ebc6ada --- /dev/null +++ b/lib/models_timm/layers/adaptive_avgmax_pool.py @@ -0,0 +1,118 @@ +""" PyTorch selectable adaptive pooling +Adaptive pooling with the ability to select the type of pooling from: + * 'avg' - Average pooling + * 'max' - Max pooling + * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 + * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim + +Both a functional and a nn.Module version of the pooling is provided. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def adaptive_pool_feat_mult(pool_type='avg'): + if pool_type == 'catavgmax': + return 2 + else: + return 1 + + +def adaptive_avgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return 0.5 * (x_avg + x_max) + + +def adaptive_catavgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return torch.cat((x_avg, x_max), 1) + + +def select_adaptive_pool2d(x, pool_type='avg', output_size=1): + """Selectable global pooling function with dynamic input kernel size + """ + if pool_type == 'avg': + x = F.adaptive_avg_pool2d(x, output_size) + elif pool_type == 'avgmax': + x = adaptive_avgmax_pool2d(x, output_size) + elif pool_type == 'catavgmax': + x = adaptive_catavgmax_pool2d(x, output_size) + elif pool_type == 'max': + x = F.adaptive_max_pool2d(x, output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + return x + + +class FastAdaptiveAvgPool2d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool2d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean((2, 3), keepdim=not self.flatten) + + +class AdaptiveAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_avgmax_pool2d(x, self.output_size) + + +class AdaptiveCatAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveCatAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_catavgmax_pool2d(x, self.output_size) + + +class SelectAdaptivePool2d(nn.Module): + """Selectable global pooling layer with dynamic input kernel size + """ + def __init__(self, output_size=1, pool_type='fast', flatten=False): + super(SelectAdaptivePool2d, self).__init__() + self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + if pool_type == '': + self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool2d(flatten) + self.flatten = nn.Identity() + elif pool_type == 'avg': + self.pool = nn.AdaptiveAvgPool2d(output_size) + elif pool_type == 'avgmax': + self.pool = AdaptiveAvgMaxPool2d(output_size) + elif pool_type == 'catavgmax': + self.pool = AdaptiveCatAvgMaxPool2d(output_size) + elif pool_type == 'max': + self.pool = nn.AdaptiveMaxPool2d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + + def is_identity(self): + return not self.pool_type + + def forward(self, x): + x = self.pool(x) + x = self.flatten(x) + return x + + def feat_mult(self): + return adaptive_pool_feat_mult(self.pool_type) + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + 'pool_type=' + self.pool_type \ + + ', flatten=' + str(self.flatten) + ')' + diff --git a/lib/models_timm/layers/attention_pool2d.py b/lib/models_timm/layers/attention_pool2d.py new file mode 100644 index 0000000..a13a688 --- /dev/null +++ b/lib/models_timm/layers/attention_pool2d.py @@ -0,0 +1,131 @@ +""" Attention Pool 2D + +Implementations of 2D spatial feature pooling using multi-head attention instead of average pool. + +Based on idea in CLIP by OpenAI, licensed Apache 2.0 +https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import Union, Tuple + +import torch +import torch.nn as nn + +from .helpers import to_2tuple +from .pos_embed import apply_rot_embed, RotaryEmbedding +from .weight_init import trunc_normal_ + + +class RotAttentionPool2d(nn.Module): + """ Attention based 2D feature pooling w/ rotary (relative) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + + Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed. + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + + NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from + train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW + """ + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + assert embed_dim % num_heads == 0 + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + self.pos_embed = RotaryEmbedding(self.head_dim) + + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + x = x.reshape(B, -1, N).permute(0, 2, 1) + + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] + + qc, q = q[:, :, :1], q[:, :, 1:] + sin_emb, cos_emb = self.pos_embed.get_embed((H, W)) + q = apply_rot_embed(q, sin_emb, cos_emb) + q = torch.cat([qc, q], dim=2) + + kc, k = k[:, :, :1], k[:, :, 1:] + k = apply_rot_embed(k, sin_emb, cos_emb) + k = torch.cat([kc, k], dim=2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] + + +class AttentionPool2d(nn.Module): + """ Attention based 2D feature pooling w/ learned (absolute) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + + It was based on impl in CLIP by OpenAI + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + + NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network. + """ + def __init__( + self, + in_features: int, + feat_size: Union[int, Tuple[int, int]], + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.feat_size = to_2tuple(feat_size) + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + + spatial_dim = self.feat_size[0] * self.feat_size[1] + self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) + trunc_normal_(self.pos_embed, std=in_features ** -0.5) + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + assert self.feat_size[0] == H + assert self.feat_size[1] == W + x = x.reshape(B, -1, N).permute(0, 2, 1) + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + x = x + self.pos_embed.unsqueeze(0).to(x.dtype) + + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] diff --git a/lib/models_timm/layers/blur_pool.py b/lib/models_timm/layers/blur_pool.py new file mode 100644 index 0000000..e73d886 --- /dev/null +++ b/lib/models_timm/layers/blur_pool.py @@ -0,0 +1,42 @@ +""" +BlurPool layer inspired by + - Kornia's Max_BlurPool2d + - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +Hacked together by Chris Ha and Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .padding import get_padding + + +class BlurPool2d(nn.Module): + r"""Creates a module that computes blurs and downsample a given feature map. + See :cite:`zhang2019shiftinvar` for more details. + Corresponds to the Downsample class, which does blurring and subsampling + + Args: + channels = Number of input channels + filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. + stride (int): downsampling filter stride + + Returns: + torch.Tensor: the transformed tensor. + """ + def __init__(self, channels, filt_size=3, stride=2) -> None: + super(BlurPool2d, self).__init__() + assert filt_size > 1 + self.channels = channels + self.filt_size = filt_size + self.stride = stride + self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 + coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) + blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) + self.register_buffer('filt', blur_filter, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding, 'reflect') + return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) diff --git a/lib/models_timm/layers/bottleneck_attn.py b/lib/models_timm/layers/bottleneck_attn.py new file mode 100644 index 0000000..c3db464 --- /dev/null +++ b/lib/models_timm/layers/bottleneck_attn.py @@ -0,0 +1,157 @@ +""" Bottleneck Self Attention (Bottleneck Transformers) + +Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + +@misc{2101.11605, +Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani}, +Title = {Bottleneck Transformers for Visual Recognition}, +Year = {2021}, +} + +Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + +This impl is a WIP but given that it is based on the ref gist likely not too far off. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import to_2tuple, make_divisible +from .weight_init import trunc_normal_ +from .trace_utils import _assert + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, heads, height, width, dim) + rel_k: (2 * width - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, 2 * W -1) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, W - 1]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) + x = x_pad[:, :W, W - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + """ + def __init__(self, feat_size, dim_head, scale): + super().__init__() + self.height, self.width = to_2tuple(feat_size) + self.dim_head = dim_head + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale) + + def forward(self, q): + B, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(B, self.height, self.width, -1) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. + q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, HW, HW) + return rel_logits + + +class BottleneckAttn(nn.Module): + """ Bottleneck Attention + Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + stride (int): output stride of the module, avg pool used if stride == 2 (default: 1). + num_heads (int): parallel attention heads (default: 4) + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections + scale_pos_embed (bool): scale the position embedding as well as Q @ K + """ + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, + qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False): + super().__init__() + assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.num_heads = num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed + + self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) + + # NOTE I'm only supporting relative pos embedding for now + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + self.reset_parameters() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + _assert(H == self.pos_embed.height, '') + _assert(W == self.pos_embed.width, '') + + x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W + + # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v + # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted. + q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) + k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k + v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) + + if self.scale_pos_embed: + attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W + else: + attn = (q @ k) * self.scale + self.pos_embed(q) + attn = attn.softmax(dim=-1) + + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W + out = self.pool(out) + return out diff --git a/lib/models_timm/layers/cbam.py b/lib/models_timm/layers/cbam.py new file mode 100644 index 0000000..576a830 --- /dev/null +++ b/lib/models_timm/layers/cbam.py @@ -0,0 +1,112 @@ +""" CBAM (sort-of) Attention + +Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 + +WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on +some tasks, especially fine-grained it seems. I may end up removing this impl. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn +import torch.nn.functional as F + +from .conv_bn_act import ConvNormAct +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible + + +class ChannelAttn(nn.Module): + """ Original CBAM channel attention module, currently avg + max pool variant only. + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(ChannelAttn, self).__init__() + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) + x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) + return x * self.gate(x_avg + x_max) + + +class LightChannelAttn(ChannelAttn): + """An experimental 'lightweight' that sums avg + max pool first + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightChannelAttn, self).__init__( + channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) + + def forward(self, x): + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) + x_attn = self.fc2(self.act(self.fc1(x_pool))) + return x * F.sigmoid(x_attn) + + +class SpatialAttn(nn.Module): + """ Original CBAM spatial attention module + """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(SpatialAttn, self).__init__() + self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class LightSpatialAttn(nn.Module): + """An experimental 'lightweight' variant that sums avg_pool and max_pool results. + """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(LightSpatialAttn, self).__init__() + self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class CbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(CbamModule, self).__init__() + self.channel = ChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + + +class LightCbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightCbamModule, self).__init__() + self.channel = LightChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = LightSpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + diff --git a/lib/models_timm/layers/classifier.py b/lib/models_timm/layers/classifier.py new file mode 100644 index 0000000..3ac3338 --- /dev/null +++ b/lib/models_timm/layers/classifier.py @@ -0,0 +1,56 @@ +""" Classifier head and layer factory + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + +from .adaptive_avgmax_pool import SelectAdaptivePool2d + + +def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): + flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling + if not pool_type: + assert num_classes == 0 or use_conv,\ + 'Pooling can only be disabled if classifier is also removed or conv classifier is used' + flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) + global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) + num_pooled_features = num_features * global_pool.feat_mult() + return global_pool, num_pooled_features + + +def _create_fc(num_features, num_classes, use_conv=False): + if num_classes <= 0: + fc = nn.Identity() # pass-through (no classifier) + elif use_conv: + fc = nn.Conv2d(num_features, num_classes, 1, bias=True) + else: + fc = nn.Linear(num_features, num_classes, bias=True) + return fc + + +def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): + global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) + fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + return global_pool, fc + + +class ClassifierHead(nn.Module): + """Classifier head w/ configurable global pooling and dropout.""" + + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) + self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() + + def forward(self, x, pre_logits: bool = False): + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + if pre_logits: + return x.flatten(1) + else: + x = self.fc(x) + return self.flatten(x) diff --git a/lib/models_timm/layers/cond_conv2d.py b/lib/models_timm/layers/cond_conv2d.py new file mode 100644 index 0000000..43654c5 --- /dev/null +++ b/lib/models_timm/layers/cond_conv2d.py @@ -0,0 +1,123 @@ +""" PyTorch Conditionally Parameterized Convolution (CondConv) + +Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference +(https://arxiv.org/abs/1904.04971) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .helpers import to_2tuple +from .conv2d_same import conv2d_same +from .padding import get_padding_value + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditionally Parameterized Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = to_2tuple(padding_val) + self.dilation = to_2tuple(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + # reshape instead of view to work with channels_last input + x = x.reshape(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/lib/models_timm/layers/config.py b/lib/models_timm/layers/config.py new file mode 100644 index 0000000..f07b9d7 --- /dev/null +++ b/lib/models_timm/layers/config.py @@ -0,0 +1,115 @@ +""" Model / Layer Config singleton state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. + """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False diff --git a/lib/models_timm/layers/conv2d_same.py b/lib/models_timm/layers/conv2d_same.py new file mode 100644 index 0000000..75f0f98 --- /dev/null +++ b/lib/models_timm/layers/conv2d_same.py @@ -0,0 +1,42 @@ +""" Conv2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + +from .padding import pad_same, get_padding_value + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/lib/models_timm/layers/conv_bn_act.py b/lib/models_timm/layers/conv_bn_act.py new file mode 100644 index 0000000..9e7c64b --- /dev/null +++ b/lib/models_timm/layers/conv_bn_act.py @@ -0,0 +1,88 @@ +""" Conv2d + BN + Act + +Hacked together by / Copyright 2020 Ross Wightman +""" +import functools +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import get_norm_act_layer + + +class ConvNormAct(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): + super(ConvNormAct, self).__init__() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +ConvBnAct = ConvNormAct + + +def create_aa(aa_layer, channels, stride=2, enable=True): + if not aa_layer or not enable: + return nn.Identity() + if isinstance(aa_layer, functools.partial): + if issubclass(aa_layer.func, nn.AvgPool2d): + return aa_layer() + else: + return aa_layer(channels) + elif issubclass(aa_layer, nn.AvgPool2d): + return aa_layer(stride) + else: + return aa_layer(channels=channels, stride=stride) + + +class ConvNormActAa(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): + super(ConvNormActAa, self).__init__() + use_aa = aa_layer is not None and stride == 2 + + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.aa(x) + return x diff --git a/lib/models_timm/layers/create_act.py b/lib/models_timm/layers/create_act.py new file mode 100644 index 0000000..a3044a3 --- /dev/null +++ b/lib/models_timm/layers/create_act.py @@ -0,0 +1,154 @@ +""" Activation Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Union, Callable, Type + +from .activations import * +from .activations_jit import * +from .activations_me import * +from .config import is_exportable, is_scriptable, is_no_jit + +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. +_has_silu = 'silu' in dir(torch.nn.functional) +_has_hardswish = 'hardswish' in dir(torch.nn.functional) +_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) +_has_mish = 'mish' in dir(torch.nn.functional) + + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=F.mish if _has_mish else mish, + relu=F.relu, + relu6=F.relu6, + leaky_relu=F.leaky_relu, + elu=F.elu, + celu=F.celu, + selu=F.selu, + gelu=gelu, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, + hard_mish=hard_mish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=F.mish if _has_mish else mish_jit, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, + hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, + hard_mish=hard_mish_jit +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=F.mish if _has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, + hard_mish=hard_mish_me, +) + +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=nn.Mish if _has_mish else Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + leaky_relu=nn.LeakyReLU, + elu=nn.ELU, + prelu=PReLU, + celu=nn.CELU, + selu=nn.SELU, + gelu=GELU, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, + hard_mish=HardMish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=nn.Mish if _has_mish else MishJit, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, + hard_mish=HardMishJit +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, + hard_mish=HardMishMe, +) + +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +def get_act_fn(name: Union[Callable, str] = 'relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if isinstance(name, Callable): + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + # If not exporting or scripting the model, first look for a memory-efficient version with + # custom autograd, then fallback + if name in _ACT_FN_ME: + return _ACT_FN_ME[name] + if not (is_no_jit() or is_exportable()): + if name in _ACT_FN_JIT: + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if not isinstance(name, str): + # callable, module, etc + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + if name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if not (is_no_jit() or is_exportable()): + if name in _ACT_LAYER_JIT: + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + +def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): + act_layer = get_act_layer(name) + if act_layer is None: + return None + if inplace is None: + return act_layer(**kwargs) + try: + return act_layer(inplace=inplace, **kwargs) + except TypeError: + # recover if act layer doesn't have inplace arg + return act_layer(**kwargs) diff --git a/lib/models_timm/layers/create_attn.py b/lib/models_timm/layers/create_attn.py new file mode 100644 index 0000000..cc7e91e --- /dev/null +++ b/lib/models_timm/layers/create_attn.py @@ -0,0 +1,89 @@ +""" Attention Factory + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from functools import partial + +from .bottleneck_attn import BottleneckAttn +from .cbam import CbamModule, LightCbamModule +from .eca import EcaModule, CecaModule +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .halo_attn import HaloAttn +from .lambda_layer import LambdaLayer +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .selective_kernel import SelectiveKernel +from .split_attn import SplitAttn +from .squeeze_excite import SEModule, EffectiveSEModule + + +def get_attn(attn_type): + if isinstance(attn_type, torch.nn.Module): + return attn_type + module_cls = None + if attn_type: + if isinstance(attn_type, str): + attn_type = attn_type.lower() + # Lightweight attention modules (channel and/or coarse spatial). + # Typically added to existing network architecture blocks in addition to existing convolutions. + if attn_type == 'se': + module_cls = SEModule + elif attn_type == 'ese': + module_cls = EffectiveSEModule + elif attn_type == 'eca': + module_cls = EcaModule + elif attn_type == 'ecam': + module_cls = partial(EcaModule, use_mlp=True) + elif attn_type == 'ceca': + module_cls = CecaModule + elif attn_type == 'ge': + module_cls = GatherExcite + elif attn_type == 'gc': + module_cls = GlobalContext + elif attn_type == 'gca': + module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) + elif attn_type == 'cbam': + module_cls = CbamModule + elif attn_type == 'lcbam': + module_cls = LightCbamModule + + # Attention / attention-like modules w/ significant params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'sk': + module_cls = SelectiveKernel + elif attn_type == 'splat': + module_cls = SplitAttn + + # Self-attention / attention-like modules w/ significant compute and/or params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'lambda': + return LambdaLayer + elif attn_type == 'bottleneck': + return BottleneckAttn + elif attn_type == 'halo': + return HaloAttn + elif attn_type == 'nl': + module_cls = NonLocalAttn + elif attn_type == 'bat': + module_cls = BatNonLocalAttn + + # Woops! + else: + assert False, "Invalid attn module (%s)" % attn_type + elif isinstance(attn_type, bool): + if attn_type: + module_cls = SEModule + else: + module_cls = attn_type + return module_cls + + +def create_attn(attn_type, channels, **kwargs): + module_cls = get_attn(attn_type) + if module_cls is not None: + # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels + return module_cls(channels, **kwargs) + return None diff --git a/lib/models_timm/layers/create_conv2d.py b/lib/models_timm/layers/create_conv2d.py new file mode 100644 index 0000000..ac9489c --- /dev/null +++ b/lib/models_timm/layers/create_conv2d.py @@ -0,0 +1,36 @@ +""" Create Conv2d Factory Method + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + if 'groups' in kwargs: + groups = kwargs.pop('groups') + if groups == in_channels: + kwargs['depthwise'] = True + else: + assert groups == 1 + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 + groups = in_channels if depthwise else kwargs.pop('groups', 1) + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + return m diff --git a/lib/models_timm/layers/create_norm.py b/lib/models_timm/layers/create_norm.py new file mode 100644 index 0000000..b9efae8 --- /dev/null +++ b/lib/models_timm/layers/create_norm.py @@ -0,0 +1,56 @@ +""" Norm Layer Factory + +Create norm modules by string (to mirror create_act and creat_norm-act fns) + +Copyright 2022 Ross Wightman +""" +import types +import functools + +import torch.nn as nn + +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d + +_NORM_MAP = dict( + batchnorm=nn.BatchNorm2d, + batchnorm2d=nn.BatchNorm2d, + batchnorm1d=nn.BatchNorm1d, + groupnorm=GroupNorm, + groupnorm1=GroupNorm1, + layernorm=LayerNorm, + layernorm2d=LayerNorm2d, +) +_NORM_TYPES = {m for n, m in _NORM_MAP.items()} + + +def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): + layer = get_norm_layer(layer_name, act_layer=act_layer) + layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + return layer_instance + + +def get_norm_layer(norm_layer): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + norm_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + + if isinstance(norm_layer, str): + layer_name = norm_layer.replace('_', '') + norm_layer = _NORM_MAP.get(layer_name, None) + elif norm_layer in _NORM_TYPES: + norm_layer = norm_layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, assume it is a lambda/fn that creates a norm layer + norm_layer = norm_layer + else: + type_name = norm_layer.__name__.lower().replace('_', '') + norm_layer = _NORM_MAP.get(type_name, None) + assert norm_layer is not None, f"No equivalent norm layer for {type_name}" + + if norm_kwargs: + norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args + return norm_layer diff --git a/lib/models_timm/layers/create_norm_act.py b/lib/models_timm/layers/create_norm_act.py new file mode 100644 index 0000000..78dd9a5 --- /dev/null +++ b/lib/models_timm/layers/create_norm_act.py @@ -0,0 +1,91 @@ +""" NormAct (Normalizaiton + Activation Layer) Factory + +Create norm + act combo modules that attempt to be backwards compatible with separate norm + act +isntances in models. Where these are used it will be possible to swap separate BN + act layers with +combined modules like IABN or EvoNorms. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import types +import functools + +from .evo_norm import * +from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d +from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d +from .inplace_abn import InplaceAbn + +_NORM_ACT_MAP = dict( + batchnorm=BatchNormAct2d, + batchnorm2d=BatchNormAct2d, + groupnorm=GroupNormAct, + groupnorm1=functools.partial(GroupNormAct, num_groups=1), + layernorm=LayerNormAct, + layernorm2d=LayerNormAct2d, + evonormb0=EvoNorm2dB0, + evonormb1=EvoNorm2dB1, + evonormb2=EvoNorm2dB2, + evonorms0=EvoNorm2dS0, + evonorms0a=EvoNorm2dS0a, + evonorms1=EvoNorm2dS1, + evonorms1a=EvoNorm2dS1a, + evonorms2=EvoNorm2dS2, + evonorms2a=EvoNorm2dS2a, + frn=FilterResponseNormAct2d, + frntlu=FilterResponseNormTlu2d, + inplaceabn=InplaceAbn, + iabn=InplaceAbn, +) +_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} +# has act_layer arg to define act type +_NORM_ACT_REQUIRES_ARG = { + BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} + + +def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): + layer = get_norm_act_layer(layer_name, act_layer=act_layer) + layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + if jit: + layer_instance = torch.jit.script(layer_instance) + return layer_instance + + +def get_norm_act_layer(norm_layer, act_layer=None): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) + norm_act_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_act_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + + if isinstance(norm_layer, str): + layer_name = norm_layer.replace('_', '').lower().split('-')[0] + norm_act_layer = _NORM_ACT_MAP.get(layer_name, None) + elif norm_layer in _NORM_ACT_TYPES: + norm_act_layer = norm_layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, must be a lambda/fn that creates a norm_act layer + norm_act_layer = norm_layer + else: + type_name = norm_layer.__name__.lower() + if type_name.startswith('batchnorm'): + norm_act_layer = BatchNormAct2d + elif type_name.startswith('groupnorm'): + norm_act_layer = GroupNormAct + elif type_name.startswith('groupnorm1'): + norm_act_layer = functools.partial(GroupNormAct, num_groups=1) + elif type_name.startswith('layernorm2d'): + norm_act_layer = LayerNormAct2d + elif type_name.startswith('layernorm'): + norm_act_layer = LayerNormAct + else: + assert False, f"No equivalent norm_act layer for {type_name}" + + if norm_act_layer in _NORM_ACT_REQUIRES_ARG: + # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. + # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types + norm_act_kwargs.setdefault('act_layer', act_layer) + if norm_act_kwargs: + norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args + return norm_act_layer diff --git a/lib/models_timm/layers/drop.py b/lib/models_timm/layers/drop.py new file mode 100644 index 0000000..1ab1c8f --- /dev/null +++ b/lib/models_timm/layers/drop.py @@ -0,0 +1,169 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def drop_block_2d( + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + # Forces the block to be inside the feature map. + w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + block_mask = torch.empty_like(x).bernoulli_(gamma) + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.empty_like(x).normal_() + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. - block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ + + def __init__( + self, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, + fast: bool = True): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' diff --git a/lib/models_timm/layers/eca.py b/lib/models_timm/layers/eca.py new file mode 100644 index 0000000..e29be6a --- /dev/null +++ b/lib/models_timm/layers/eca.py @@ -0,0 +1,145 @@ +""" +ECA module from ECAnet + +paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks +https://arxiv.org/abs/1910.03151 + +Original ECA model borrowed from https://github.com/BangguWu/ECANet + +Modified circular ECA implementation and adaption for use in timm package +by Chris Ha https://github.com/VRandme + +Original License: + +MIT License + +Copyright (c) 2019 BangguWu, Qilong Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import math +from torch import nn +import torch.nn.functional as F + + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class EcaModule(nn.Module): + """Constructs an ECA module. + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + def __init__( + self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', + rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): + super(EcaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + assert kernel_size % 2 == 1 + padding = (kernel_size - 1) // 2 + if use_mlp: + # NOTE 'mlp' mode is a timm experiment, not in paper + assert channels is not None + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) + act_layer = act_layer or nn.ReLU + self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) + self.act = create_act_layer(act_layer) + self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) + else: + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.act = None + self.conv2 = None + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv + y = self.conv(y) + if self.conv2 is not None: + y = self.act(y) + y = self.conv2(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +EfficientChannelAttn = EcaModule # alias + + +class CecaModule(nn.Module): + """Constructs a circular ECA module. + + ECA module where the conv uses circular padding rather than zero padding. + Unlike the spatial dimension, the channels do not have inherent ordering nor + locality. Although this module in essence, applies such an assumption, it is unnecessary + to limit the channels on either "edge" from being circularly adapted to each other. + This will fundamentally increase connectivity and possibly increase performance metrics + (accuracy, robustness), without significantly impacting resource metrics + (parameter size, throughput,latency, etc) + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): + super(CecaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + has_act = act_layer is not None + assert kernel_size % 2 == 1 + + # PyTorch circular padding mode is buggy as of pytorch 1.4 + # see https://github.com/pytorch/pytorch/pull/17240 + # implement manual circular padding + self.padding = (kernel_size - 1) // 2 + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) + # Manually implement circular padding, F.pad does not seemed to be bugged + y = F.pad(y, (self.padding, self.padding), mode='circular') + y = self.conv(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +CircularEfficientChannelAttn = CecaModule diff --git a/lib/models_timm/layers/evo_norm.py b/lib/models_timm/layers/evo_norm.py new file mode 100644 index 0000000..ea77620 --- /dev/null +++ b/lib/models_timm/layers/evo_norm.py @@ -0,0 +1,352 @@ +""" EvoNorm in PyTorch + +Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967 +@inproceedings{NEURIPS2020, + author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin}, + pages = {13539--13550}, + publisher = {Curran Associates, Inc.}, + title = {Evolving Normalization-Activation Layers}, + url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf}, + volume = {33}, + year = {2020} +} + +An attempt at getting decent performing EvoNorms running in PyTorch. +While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm +in terms of memory usage and throughput on GPUs. + +I'm testing these modules on TPU w/ PyTorch XLA. Promising start but +currently working around some issues with builtin torch/tensor.var/std. Unlike +GPU, similar train speeds for EvoNormS variants and BatchNorm. + +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer +from .trace_utils import _assert + + +def instance_std(x, eps: float = 1e-5): + std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype) + return std.expand(x.shape) + + +def instance_std_tpu(x, eps: float = 1e-5): + std = manual_var(x, dim=(2, 3)).add(eps).sqrt() + return std.expand(x.shape) +# instance_std = instance_std_tpu + + +def instance_rms(x, eps: float = 1e-5): + rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype) + return rms.expand(x.shape) + + +def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False): + xm = x.mean(dim=dim, keepdim=True) + if diff_sqm: + # difference of squared mean and mean squared, faster on TPU can be less stable + var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0) + else: + var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True) + return var + + +def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False): + B, C, H, W = x.shape + x_dtype = x.dtype + _assert(C % groups == 0, '') + if flatten: + x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues + std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype) + else: + x = x.reshape(B, groups, C // groups, H, W) + std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype) + return std.expand(x.shape).reshape(B, C, H, W) + + +def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False): + # This is a workaround for some stability / odd behaviour of .var and .std + # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results + B, C, H, W = x.shape + _assert(C % groups == 0, '') + if flatten: + x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues + var = manual_var(x, dim=-1, diff_sqm=diff_sqm) + else: + x = x.reshape(B, groups, C // groups, H, W) + var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm) + return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W) +#group_std = group_std_tpu # FIXME TPU temporary + + +def group_rms(x, groups: int = 32, eps: float = 1e-5): + B, C, H, W = x.shape + _assert(C % groups == 0, '') + x_dtype = x.dtype + x = x.reshape(B, groups, C // groups, H, W) + rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype) + return rms.expand(x.shape).reshape(B, C, H, W) + + +class EvoNorm2dB0(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-3, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.v is not None: + nn.init.ones_(self.v) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.v is not None: + if self.training: + var = x.float().var(dim=(0, 2, 3), unbiased=False) + # var = manual_var(x, dim=(0, 2, 3)).squeeze() + n = x.numel() / x.shape[1] + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach() * self.momentum * (n / (n - 1))) + else: + var = self.running_var + left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x) + v = self.v.to(x_dtype).view(v_shape) + right = x * v + instance_std(x, self.eps) + x = x / left.max(right) + return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape) + + +class EvoNorm2dB1(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + if self.training: + var = x.float().var(dim=(0, 2, 3), unbiased=False) + n = x.numel() / x.shape[1] + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1))) + else: + var = self.running_var + var = var.to(x_dtype).view(v_shape) + left = var.add(self.eps).sqrt_() + right = (x + 1) * instance_rms(x, self.eps) + x = x / left.max(right) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dB2(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + if self.training: + var = x.float().var(dim=(0, 2, 3), unbiased=False) + n = x.numel() / x.shape[1] + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1))) + else: + var = self.running_var + var = var.to(x_dtype).view(v_shape) + left = var.add(self.eps).sqrt_() + right = instance_rms(x, self.eps) - x + x = x / left.max(right) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS0(nn.Module): + def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.v is not None: + nn.init.ones_(self.v) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.v is not None: + v = self.v.view(v_shape).to(x_dtype) + x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS0a(EvoNorm2dS0): + def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-3, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + d = group_std(x, self.groups, self.eps) + if self.v is not None: + v = self.v.view(v_shape).to(x_dtype) + x = x * (x * v).sigmoid() + x = x / d + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS1(nn.Module): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=None, eps=1e-5, **_): + super().__init__() + act_layer = act_layer or nn.SiLU + self.apply_act = apply_act # apply activation (non-linearity) + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer) + else: + self.act = nn.Identity() + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.pre_act_norm = False + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + x = self.act(x) / group_std(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS1a(EvoNorm2dS1): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=None, eps=1e-3, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = self.act(x) / group_std(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS2(nn.Module): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=None, eps=1e-5, **_): + super().__init__() + act_layer = act_layer or nn.SiLU + self.apply_act = apply_act # apply activation (non-linearity) + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer) + else: + self.act = nn.Identity() + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + x = self.act(x) / group_rms(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS2a(EvoNorm2dS2): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=None, eps=1e-3, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = self.act(x) / group_rms(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) diff --git a/lib/models_timm/layers/fast_norm.py b/lib/models_timm/layers/fast_norm.py new file mode 100644 index 0000000..fb35e47 --- /dev/null +++ b/lib/models_timm/layers/fast_norm.py @@ -0,0 +1,78 @@ +""" 'Fast' Normalization Functions + +For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. + +Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) + +Hacked together by / Copyright 2022 Ross Wightman +""" +from typing import List, Optional + +import torch +from torch.nn import functional as F + +try: + from apex.normalization.fused_layer_norm import fused_layer_norm_affine + has_apex = True +except ImportError: + has_apex = False + + +# fast (ie lower precision LN) can be disabled with this flag if issues crop up +_USE_FAST_NORM = False # defaulting to False for now + + +def is_fast_norm(): + return _USE_FAST_NORM + + +def set_fast_norm(enable=True): + global _USE_FAST_NORM + _USE_FAST_NORM = enable + + +def fast_group_norm( + x: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5 +) -> torch.Tensor: + if torch.jit.is_scripting(): + # currently cannot use is_autocast_enabled within torchscript + return F.group_norm(x, num_groups, weight, bias, eps) + + if torch.is_autocast_enabled(): + # normally native AMP casts GN inputs to float32 + # here we use the low precision autocast dtype + # FIXME what to do re CPU autocast? + dt = torch.get_autocast_gpu_dtype() + x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) + + with torch.cuda.amp.autocast(enabled=False): + return F.group_norm(x, num_groups, weight, bias, eps) + + +def fast_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5 +) -> torch.Tensor: + if torch.jit.is_scripting(): + # currently cannot use is_autocast_enabled within torchscript + return F.layer_norm(x, normalized_shape, weight, bias, eps) + + if has_apex: + return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) + + if torch.is_autocast_enabled(): + # normally native AMP casts LN inputs to float32 + # apex LN does not, this is behaving like Apex + dt = torch.get_autocast_gpu_dtype() + # FIXME what to do re CPU autocast? + x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) + + with torch.cuda.amp.autocast(enabled=False): + return F.layer_norm(x, normalized_shape, weight, bias, eps) diff --git a/lib/models_timm/layers/filter_response_norm.py b/lib/models_timm/layers/filter_response_norm.py new file mode 100644 index 0000000..a66a1cd --- /dev/null +++ b/lib/models_timm/layers/filter_response_norm.py @@ -0,0 +1,68 @@ +""" Filter Response Norm in PyTorch + +Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +import torch.nn as nn + +from .create_act import create_act_layer +from .trace_utils import _assert + + +def inv_instance_rms(x, eps: float = 1e-5): + rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) + return rms.expand(x.shape) + + +class FilterResponseNormTlu2d(nn.Module): + def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): + super(FilterResponseNormTlu2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.rms = rms + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.tau is not None: + nn.init.zeros_(self.tau) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = x * inv_instance_rms(x, self.eps) + x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x + + +class FilterResponseNormAct2d(nn.Module): + def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): + super(FilterResponseNormAct2d, self).__init__() + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer, inplace=inplace) + else: + self.act = nn.Identity() + self.rms = rms + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = x * inv_instance_rms(x, self.eps) + x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return self.act(x) diff --git a/lib/models_timm/layers/gather_excite.py b/lib/models_timm/layers/gather_excite.py new file mode 100644 index 0000000..2d60dc9 --- /dev/null +++ b/lib/models_timm/layers/gather_excite.py @@ -0,0 +1,90 @@ +""" Gather-Excite Attention Block + +Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 + +Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet + +I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another +impl that covers all of the cases. + +NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math + +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .create_conv2d import create_conv2d +from .helpers import make_divisible +from .mlp import ConvMlp + + +class GatherExcite(nn.Module): + """ Gather-Excite Attention Module + """ + def __init__( + self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, + rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): + super(GatherExcite, self).__init__() + self.add_maxpool = add_maxpool + act_layer = get_act_layer(act_layer) + self.extent = extent + if extra_params: + self.gather = nn.Sequential() + if extent == 0: + assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' + self.gather.add_module( + 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) + else: + assert extent % 2 == 0 + num_conv = int(math.log2(extent)) + for i in range(num_conv): + self.gather.add_module( + f'conv{i + 1}', + create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) + if i != num_conv - 1: + self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) + else: + self.gather = None + if self.extent == 0: + self.gk = 0 + self.gs = 0 + else: + assert extent % 2 == 0 + self.gk = self.extent * 2 - 1 + self.gs = self.extent + + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + size = x.shape[-2:] + if self.gather is not None: + x_ge = self.gather(x) + else: + if self.extent == 0: + # global extent + x_ge = x.mean(dim=(2, 3), keepdims=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) + else: + x_ge = F.avg_pool2d( + x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) + x_ge = self.mlp(x_ge) + if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: + x_ge = F.interpolate(x_ge, size=size) + return x * self.gate(x_ge) diff --git a/lib/models_timm/layers/global_context.py b/lib/models_timm/layers/global_context.py new file mode 100644 index 0000000..de7fb5c --- /dev/null +++ b/lib/models_timm/layers/global_context.py @@ -0,0 +1,67 @@ +""" Global Context Attention Block + +Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` + - https://arxiv.org/abs/1904.11492 + +Official code consulted as reference: https://github.com/xvjiarui/GCNet + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible +from .mlp import ConvMlp +from .norm import LayerNorm2d + + +class GlobalContext(nn.Module): + + def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, + rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): + super(GlobalContext, self).__init__() + act_layer = get_act_layer(act_layer) + + self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None + + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + if fuse_add: + self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_add = None + if fuse_scale: + self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_scale = None + + self.gate = create_act_layer(gate_layer) + self.init_last_zero = init_last_zero + self.reset_parameters() + + def reset_parameters(self): + if self.conv_attn is not None: + nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') + if self.mlp_add is not None: + nn.init.zeros_(self.mlp_add.fc2.weight) + + def forward(self, x): + B, C, H, W = x.shape + + if self.conv_attn is not None: + attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) + attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = context.view(B, C, 1, 1) + else: + context = x.mean(dim=(2, 3), keepdim=True) + + if self.mlp_scale is not None: + mlp_x = self.mlp_scale(context) + x = x * self.gate(mlp_x) + if self.mlp_add is not None: + mlp_x = self.mlp_add(context) + x = x + mlp_x + + return x diff --git a/lib/models_timm/layers/halo_attn.py b/lib/models_timm/layers/halo_attn.py new file mode 100644 index 0000000..f2ac64f --- /dev/null +++ b/lib/models_timm/layers/halo_attn.py @@ -0,0 +1,233 @@ +""" Halo Self Attention + +Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + +@misc{2103.12731, +Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and + Jonathon Shlens}, +Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones}, +Year = {2021}, +} + +Status: +This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me. +The attention mechanism works but it's slow as implemented. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +from torch import nn +import torch.nn.functional as F + +from .helpers import make_divisible +from .weight_init import trunc_normal_ +from .trace_utils import _assert + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, height, width, dim) + rel_k: (2 * window - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + rel_size = rel_k.shape[0] + win_size = (rel_size + 1) // 2 + + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, rel_size) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, rel_size - W]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, rel_size) + x = x_pad[:, :W, win_size - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + """ + def __init__(self, block_size, win_size, dim_head, scale): + """ + Args: + block_size (int): block size + win_size (int): neighbourhood window size + dim_head (int): attention head dim + scale (float): scale factor (for init) + """ + super().__init__() + self.block_size = block_size + self.dim_head = dim_head + self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) + + def forward(self, q): + B, BB, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(-1, self.block_size, self.block_size, self.dim_head) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. + q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, BB, HW, -1) + return rel_logits + + +class HaloAttn(nn.Module): + """ Halo Attention + + Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda) + stride: output stride of the module, query downscaled if > 1 (default: 1). + num_heads: parallel attention heads (default: 8). + dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + block_size (int): size of blocks. (default: 8) + halo_size (int): size of halo overlap. (default: 3) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) + qkv_bias (bool) : add bias to q, k, and v projections + avg_down (bool): use average pool downsample instead of strided query blocks + scale_pos_embed (bool): scale the position embedding as well as Q @ K + """ + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, + qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False): + super().__init__() + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + assert stride in (1, 2) + self.num_heads = num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed + self.block_size = self.block_size_ds = block_size + self.halo_size = halo_size + self.win_size = block_size + halo_size * 2 # neighbourhood window size + self.block_stride = 1 + use_avg_pool = False + if stride > 1: + use_avg_pool = avg_down or block_size % stride != 0 + self.block_stride = 1 if use_avg_pool else stride + self.block_size_ds = self.block_size // self.block_stride + + # FIXME not clear if this stride behaviour is what the paper intended + # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving + # data in unfolded block form. I haven't wrapped my head around how that'd look. + self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias) + self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) + + self.pos_embed = PosEmbedRel( + block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity() + + self.reset_parameters() + + def reset_parameters(self): + std = self.q.weight.shape[1] ** -0.5 # fan-in + trunc_normal_(self.q.weight, std=std) + trunc_normal_(self.kv.weight, std=std) + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + _assert(H % self.block_size == 0, '') + _assert(W % self.block_size == 0, '') + num_h_blocks = H // self.block_size + num_w_blocks = W // self.block_size + num_blocks = num_h_blocks * num_w_blocks + + q = self.q(x) + # unfold + q = q.reshape( + -1, self.dim_head_qk, + num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4) + # B, num_heads * dim_head * block_size ** 2, num_blocks + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3) + # B * num_heads, num_blocks, block_size ** 2, dim_head + + kv = self.kv(x) + # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not + # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach. + # FIXME figure out how to switch impl between this and conv2d if XLA being used. + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) + kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( + B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1) + k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1) + # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v + + if self.scale_pos_embed: + attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale + else: + attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q) + # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 + attn = attn.softmax(dim=-1) + + out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks + # fold + out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks) + out = out.permute(0, 3, 1, 4, 2).contiguous().view( + B, self.dim_out_v, H // self.block_stride, W // self.block_stride) + # B, dim_out, H // block_stride, W // block_stride + out = self.pool(out) + return out + + +""" Three alternatives for overlapping windows. + +`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold() + + if is_xla: + # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is + # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment. + WW = self.win_size ** 2 + pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size) + kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size) + elif self.stride_tricks: + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() + kv = kv.as_strided(( + B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), + stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) + else: + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + + kv = kv.reshape( + B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) +""" diff --git a/lib/models_timm/layers/helpers.py b/lib/models_timm/layers/helpers.py new file mode 100644 index 0000000..2fa296b --- /dev/null +++ b/lib/models_timm/layers/helpers.py @@ -0,0 +1,43 @@ +""" Layer/Module Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +from itertools import repeat +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < round_limit * v: + new_v += divisor + return new_v + + +def extend_tuple(x, n): + # pdas a tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + else: + x = tuple(x) + pad_n = n - len(x) + if pad_n <= 0: + return x[:n] + return x + (x[-1],) * pad_n diff --git a/lib/models_timm/layers/inplace_abn.py b/lib/models_timm/layers/inplace_abn.py new file mode 100644 index 0000000..a808893 --- /dev/null +++ b/lib/models_timm/layers/inplace_abn.py @@ -0,0 +1,87 @@ +import torch +from torch import nn as nn + +try: + from inplace_abn.functions import inplace_abn, inplace_abn_sync + has_iabn = True +except ImportError: + has_iabn = False + + def inplace_abn(x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): + raise ImportError( + "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") + + def inplace_abn_sync(**kwargs): + inplace_abn(**kwargs) + + +class InplaceAbn(nn.Module): + """Activated Batch Normalization + + This gathers a BatchNorm and an activation function in a single module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + act_layer : str or nn.Module type + Name or type of the activation functions, one of: `leaky_relu`, `elu` + act_param : float + Negative slope for the `leaky_relu` activation. + """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, + act_layer="leaky_relu", act_param=0.01, drop_layer=None): + super(InplaceAbn, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.momentum = momentum + if apply_act: + if isinstance(act_layer, str): + assert act_layer in ('leaky_relu', 'elu', 'identity', '') + self.act_name = act_layer if act_layer else 'identity' + else: + # convert act layer passed as type to string + if act_layer == nn.ELU: + self.act_name = 'elu' + elif act_layer == nn.LeakyReLU: + self.act_name = 'leaky_relu' + elif act_layer is None or act_layer == nn.Identity: + self.act_name = 'identity' + else: + assert False, f'Invalid act layer {act_layer.__name__} for IABN' + else: + self.act_name = 'identity' + self.act_param = act_param + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.running_mean, 0) + nn.init.constant_(self.running_var, 1) + if self.affine: + nn.init.constant_(self.weight, 1) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + output = inplace_abn( + x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.act_name, self.act_param) + if isinstance(output, tuple): + output = output[0] + return output diff --git a/lib/models_timm/layers/lambda_layer.py b/lib/models_timm/layers/lambda_layer.py new file mode 100644 index 0000000..e50b43c --- /dev/null +++ b/lib/models_timm/layers/lambda_layer.py @@ -0,0 +1,133 @@ +""" Lambda Layer + +Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + +@misc{2102.08602, +Author = {Irwan Bello}, +Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, +Year = {2021}, +} + +Status: +This impl is a WIP. Code snippets in the paper were used as reference but +good chance some details are missing/wrong. + +I've only implemented local lambda conv based pos embeddings. + +For a PyTorch impl that includes other embedding options checkout +https://github.com/lucidrains/lambda-networks + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from torch import nn +import torch.nn.functional as F + +from .helpers import to_2tuple, make_divisible +from .weight_init import trunc_normal_ + + +def rel_pos_indices(size): + size = to_2tuple(size) + pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1) + rel_pos = pos[:, None, :] - pos[:, :, None] + rel_pos[0] += size[0] - 1 + rel_pos[1] += size[1] - 1 + return rel_pos # 2, H * W, H * W + + +class LambdaLayer(nn.Module): + """Lambda Layer + + Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + + NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. + + The internal dimensions of the lambda module are controlled via the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query (q) and key (k) dimension are determined by + * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None + * q = num_heads * dim_head, k = dim_head + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W + stride (int): output stride of the module, avg pool used if stride == 2 + num_heads (int): parallel attention heads. + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections + """ + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, + qk_ratio=1.0, qkv_bias=False): + super().__init__() + dim_out = dim_out or dim + assert dim_out % num_heads == 0, ' should be divided by num_heads' + self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.num_heads = num_heads + self.dim_v = dim_out // num_heads + + self.qkv = nn.Conv2d( + dim, + num_heads * self.dim_qk + self.dim_qk + self.dim_v, + kernel_size=1, bias=qkv_bias) + self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) + self.norm_v = nn.BatchNorm2d(self.dim_v) + + if r is not None: + # local lambda convolution for pos + self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) + self.pos_emb = None + self.rel_pos_indices = None + else: + # relative pos embedding + assert feat_size is not None + feat_size = to_2tuple(feat_size) + rel_size = [2 * s - 1 for s in feat_size] + self.conv_lambda = None + self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) + self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + self.reset_parameters() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in + if self.conv_lambda is not None: + trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5) + if self.pos_emb is not None: + trunc_normal_(self.pos_emb, std=.02) + + def forward(self, x): + B, C, H, W = x.shape + M = H * W + qkv = self.qkv(x) + q, k, v = torch.split(qkv, [ + self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1) + q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K + v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V + k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M + + content_lam = k @ v # B, K, V + content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V + + if self.pos_emb is None: + position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K + position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + else: + # FIXME relative pos embedding path not fully verified + pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) + position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V + position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V + + out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W + out = self.pool(out) + return out diff --git a/lib/models_timm/layers/linear.py b/lib/models_timm/layers/linear.py new file mode 100644 index 0000000..38fe338 --- /dev/null +++ b/lib/models_timm/layers/linear.py @@ -0,0 +1,19 @@ +""" Linear layer (alternate definition) +""" +import torch +import torch.nn.functional as F +from torch import nn as nn + + +class Linear(nn.Linear): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` + + Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting + weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. + """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None + return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) + else: + return F.linear(input, self.weight, self.bias) diff --git a/lib/models_timm/layers/median_pool.py b/lib/models_timm/layers/median_pool.py new file mode 100644 index 0000000..40bd71a --- /dev/null +++ b/lib/models_timm/layers/median_pool.py @@ -0,0 +1,49 @@ +""" Median Pool +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F +from .helpers import to_2tuple, to_4tuple + + +class MedianPool2d(nn.Module): + """ Median pool (usable as median filter when stride=1) module. + + Args: + kernel_size: size of pooling kernel, int or 2-tuple + stride: pool stride, int or 2-tuple + padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad + same: override padding and enforce same padding, boolean + """ + def __init__(self, kernel_size=3, stride=1, padding=0, same=False): + super(MedianPool2d, self).__init__() + self.k = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + self.padding = to_4tuple(padding) # convert to l, r, t, b + self.same = same + + def _padding(self, x): + if self.same: + ih, iw = x.size()[2:] + if ih % self.stride[0] == 0: + ph = max(self.k[0] - self.stride[0], 0) + else: + ph = max(self.k[0] - (ih % self.stride[0]), 0) + if iw % self.stride[1] == 0: + pw = max(self.k[1] - self.stride[1], 0) + else: + pw = max(self.k[1] - (iw % self.stride[1]), 0) + pl = pw // 2 + pr = pw - pl + pt = ph // 2 + pb = ph - pt + padding = (pl, pr, pt, pb) + else: + padding = self.padding + return padding + + def forward(self, x): + x = F.pad(x, self._padding(x), mode='reflect') + x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) + x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] + return x diff --git a/lib/models_timm/layers/mixed_conv2d.py b/lib/models_timm/layers/mixed_conv2d.py new file mode 100644 index 0000000..fa0ce56 --- /dev/null +++ b/lib/models_timm/layers/mixed_conv2d.py @@ -0,0 +1,51 @@ +""" PyTorch Mixed Convolution + +Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = in_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x diff --git a/lib/models_timm/layers/ml_decoder.py b/lib/models_timm/layers/ml_decoder.py new file mode 100644 index 0000000..3f828c6 --- /dev/null +++ b/lib/models_timm/layers/ml_decoder.py @@ -0,0 +1,156 @@ +from typing import Optional + +import torch +from torch import nn +from torch import nn, Tensor +from torch.nn.modules.transformer import _get_activation_fn + + +def add_ml_decoder_head(model): + if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50 + model.global_pool = nn.Identity() + del model.fc + num_classes = model.num_classes + num_features = model.num_features + model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet + model.global_pool = nn.Identity() + del model.classifier + num_classes = model.num_classes + num_features = model.num_features + model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head') + del model.head + num_classes = model.num_classes + num_features = model.num_features + model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + else: + print("Model code-writing is not aligned currently with ml-decoder") + exit(-1) + if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout + model.drop_rate = 0 + return model + + +class TransformerDecoderLayerOptimal(nn.Module): + def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu", + layer_norm_eps=1e-5) -> None: + super(TransformerDecoderLayerOptimal, self).__init__() + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = torch.nn.functional.relu + super(TransformerDecoderLayerOptimal, self).__setstate__(state) + + def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: + tgt = tgt + self.dropout1(tgt) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, memory)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +# @torch.jit.script +# class ExtrapClasses(object): +# def __init__(self, num_queries: int, group_size: int): +# self.num_queries = num_queries +# self.group_size = group_size +# +# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap: +# torch.Tensor): +# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size) +# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups]) +# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size)) +# out = (h * w).sum(dim=2) + class_embed_b +# out = out.view((h.shape[0], self.group_size * self.num_queries)) +# return out + +@torch.jit.script +class GroupFC(object): + def __init__(self, embed_len_decoder: int): + self.embed_len_decoder = embed_len_decoder + + def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): + for i in range(self.embed_len_decoder): + h_i = h[:, i, :] + w_i = duplicate_pooling[i, :, :] + out_extrap[:, i, :] = torch.matmul(h_i, w_i) + + +class MLDecoder(nn.Module): + def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048): + super(MLDecoder, self).__init__() + embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups + if embed_len_decoder > num_classes: + embed_len_decoder = num_classes + + # switching to 768 initial embeddings + decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding + self.embed_standart = nn.Linear(initial_num_features, decoder_embedding) + + # decoder + decoder_dropout = 0.1 + num_layers_decoder = 1 + dim_feedforward = 2048 + layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding, + dim_feedforward=dim_feedforward, dropout=decoder_dropout) + self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder) + + # non-learnable queries + self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding) + self.query_embed.requires_grad_(False) + + # group fully-connected + self.num_classes = num_classes + self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999) + self.duplicate_pooling = torch.nn.Parameter( + torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor)) + self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) + torch.nn.init.xavier_normal_(self.duplicate_pooling) + torch.nn.init.constant_(self.duplicate_pooling_bias, 0) + self.group_fc = GroupFC(embed_len_decoder) + + def forward(self, x): + if len(x.shape) == 4: # [bs,2048, 7,7] + embedding_spatial = x.flatten(2).transpose(1, 2) + else: # [bs, 197,468] + embedding_spatial = x + embedding_spatial_786 = self.embed_standart(embedding_spatial) + embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True) + + bs = embedding_spatial_786.shape[0] + query_embed = self.query_embed.weight + # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) + tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand + h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768] + h = h.transpose(0, 1) + + out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) + self.group_fc(h, self.duplicate_pooling, out_extrap) + h_out = out_extrap.flatten(1)[:, :self.num_classes] + h_out += self.duplicate_pooling_bias + logits = h_out + return logits diff --git a/lib/models_timm/layers/mlp.py b/lib/models_timm/layers/mlp.py new file mode 100644 index 0000000..91e80a8 --- /dev/null +++ b/lib/models_timm/layers/mlp.py @@ -0,0 +1,126 @@ +""" MLP module w/ dropout and configurable activation layer + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .helpers import to_2tuple + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class GluMlp(nn.Module): + """ MLP w/ GLU style gating + See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + assert hidden_features % 2 == 0 + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + fc1_mid = self.fc1.bias.shape[0] // 2 + nn.init.ones_(self.fc1.bias[fc1_mid:]) + nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x, gates = x.chunk(2, dim=-1) + x = x * self.act(gates) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class GatedMlp(nn.Module): + """ MLP as used in gMLP + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + gate_layer=None, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + if gate_layer is not None: + assert hidden_features % 2 == 0 + self.gate = gate_layer(hidden_features) + hidden_features = hidden_features // 2 # FIXME base reduction on gate property? + else: + self.gate = nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.gate(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, + norm_layer=None, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.drop = nn.Dropout(drop) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x diff --git a/lib/models_timm/layers/non_local_attn.py b/lib/models_timm/layers/non_local_attn.py new file mode 100644 index 0000000..670e8f2 --- /dev/null +++ b/lib/models_timm/layers/non_local_attn.py @@ -0,0 +1,145 @@ +""" Bilinear-Attention-Transform and Non-Local Attention + +Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms` + - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html +Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification +""" +import torch +from torch import nn +from torch.nn import functional as F + +from .conv_bn_act import ConvNormAct +from .helpers import make_divisible +from .trace_utils import _assert + + +class NonLocalAttn(nn.Module): + """Spatial NL block for image classification. + + This was adapted from https://github.com/BA-Transform/BAT-Image-Classification + Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. + """ + + def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): + super(NonLocalAttn, self).__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.scale = in_channels ** -0.5 if use_scale else 1.0 + self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) + self.norm = nn.BatchNorm2d(in_channels) + self.reset_parameters() + + def forward(self, x): + shortcut = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + B, C, H, W = t.size() + t = t.view(B, C, -1).permute(0, 2, 1) + p = p.view(B, C, -1) + g = g.view(B, C, -1).permute(0, 2, 1) + + att = torch.bmm(t, p) * self.scale + att = F.softmax(att, dim=2) + x = torch.bmm(att, g) + + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.z(x) + x = self.norm(x) + shortcut + + return x + + def reset_parameters(self): + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if len(list(m.parameters())) > 1: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + + +class BilinearAttnTransform(nn.Module): + + def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(BilinearAttnTransform, self).__init__() + + self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) + self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) + self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.block_size = block_size + self.groups = groups + self.in_channels = in_channels + + def resize_mat(self, x, t: int): + B, C, block_size, block_size1 = x.shape + _assert(block_size == block_size1, '') + if t <= 1: + return x + x = x.view(B * C, -1, 1, 1) + x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) + x = x.view(B * C, block_size, block_size, t, t) + x = torch.cat(torch.split(x, 1, dim=1), dim=3) + x = torch.cat(torch.split(x, 1, dim=2), dim=4) + x = x.view(B, C, block_size * t, block_size * t) + return x + + def forward(self, x): + _assert(x.shape[-1] % self.block_size == 0, '') + _assert(x.shape[-2] % self.block_size == 0, '') + B, C, H, W = x.shape + out = self.conv1(x) + rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) + cp = F.adaptive_max_pool2d(out, (1, self.block_size)) + p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + p = p / p.sum(dim=3, keepdim=True) + q = q / q.sum(dim=2, keepdim=True) + p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + p = p.view(B, C, self.block_size, self.block_size) + q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + q = q.view(B, C, self.block_size, self.block_size) + p = self.resize_mat(p, H // self.block_size) + q = self.resize_mat(q, W // self.block_size) + y = p.matmul(x) + y = y.matmul(q) + + y = self.conv2(y) + return y + + +class BatNonLocalAttn(nn.Module): + """ BAT + Adapted from: https://github.com/BA-Transform/BAT-Image-Classification + """ + + def __init__( + self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): + super().__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) + self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.dropout = nn.Dropout2d(p=drop_rate) + + def forward(self, x): + xl = self.conv1(x) + y = self.ba(xl) + y = self.conv2(y) + y = self.dropout(y) + return y + x diff --git a/lib/models_timm/layers/norm.py b/lib/models_timm/layers/norm.py new file mode 100644 index 0000000..42445a4 --- /dev/null +++ b/lib/models_timm/layers/norm.py @@ -0,0 +1,117 @@ +""" Normalization layers and wrappers + +Norm layer definitions that support fast norm and consistent channel arg order (always first arg). + +Hacked together by / Copyright 2022 Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm + + +class GroupNorm(nn.GroupNorm): + def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): + # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN + super().__init__(num_groups, num_channels, eps=eps, affine=affine) + self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + def forward(self, x): + if self.fast_norm: + return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class GroupNorm1(nn.GroupNorm): + """ Group Normalization with 1 group. + Input: tensor in shape [B, C, *] + """ + + def __init__(self, num_channels, **kwargs): + super().__init__(1, num_channels, **kwargs) + self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.fast_norm: + return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class LayerNorm(nn.LayerNorm): + """ LayerNorm w/ fast norm option + """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x + + +class LayerNorm2d(nn.LayerNorm): + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) + return x + + +def _is_contiguous(tensor: torch.Tensor) -> bool: + # jit is oh so lovely :/ + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) + + +@torch.jit.script +def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): + s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) + x = (x - u) * torch.rsqrt(s + eps) + x = x * weight[:, None, None] + bias[:, None, None] + return x + + +def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): + u = x.mean(dim=1, keepdim=True) + s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0) + x = (x - u) * torch.rsqrt(s + eps) + x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) + return x + + +class LayerNormExp2d(nn.LayerNorm): + """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). + + Experimental implementation w/ manual norm for tensors non-contiguous tensors. + + This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last + layout. However, benefits are not always clear and can perform worse on other GPUs. + """ + + def __init__(self, num_channels, eps=1e-6): + super().__init__(num_channels, eps=eps) + + def forward(self, x) -> torch.Tensor: + if _is_contiguous(x): + x = F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + else: + x = _layer_norm_cf(x, self.weight, self.bias, self.eps) + return x diff --git a/lib/models_timm/layers/norm_act.py b/lib/models_timm/layers/norm_act.py new file mode 100644 index 0000000..ff075fb --- /dev/null +++ b/lib/models_timm/layers/norm_act.py @@ -0,0 +1,252 @@ +""" Normalization + Activation Layers + +Provides Norm+Act fns for standard PyTorch norm layers such as +* BatchNorm +* GroupNorm +* LayerNorm + +This allows swapping with alternative layers that are natively both norm + act such as +* EvoNorm (evo_norm.py) +* FilterResponseNorm (filter_response_norm.py) +* InplaceABN (inplace_abn.py) + +Hacked together by / Copyright 2022 Ross Wightman +""" +from typing import Union, List, Optional, Any + +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .create_act import get_act_layer +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm +from .trace_utils import _assert + + +class BatchNormAct2d(nn.BatchNorm2d): + """BatchNorm + Activation + + This module performs BatchNorm + Activation in a manner that will remain backwards + compatible with weights trained with separate bn, act. This is why we inherit from BN + instead of composing it as a .bn member. + """ + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + device=None, + dtype=None + ): + try: + factory_kwargs = {'device': device, 'dtype': dtype} + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats, + **factory_kwargs + ) + except TypeError: + # NOTE for backwards compat with old PyTorch w/o factory device/dtype support + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing + _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)') + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + x = F.batch_norm( + x, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean if not self.training or self.track_running_stats else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + x = self.drop(x) + x = self.act(x) + return x + + +class SyncBatchNormAct(nn.SyncBatchNorm): + # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254) + # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers + # but ONLY when used in conjunction with the timm conversion function below. + # Do not create this module directly or use the PyTorch conversion function. + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine + if hasattr(self, "drop"): + x = self.drop(x) + if hasattr(self, "act"): + x = self.act(x) + return x + + +def convert_sync_batchnorm(module, process_group=None): + # convert both BatchNorm and BatchNormAct layers to Synchronized variants + module_output = module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + if isinstance(module, BatchNormAct2d): + # convert timm norm + act layer + module_output = SyncBatchNormAct( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group=process_group, + ) + # set act and drop attr from the original module + module_output.act = module.act + module_output.drop = module.drop + else: + # convert standard BatchNorm layers + module_output = torch.nn.SyncBatchNorm( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group, + ) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module(name, convert_sync_batchnorm(child, process_group)) + del module + return module_output + + +def _num_groups(num_channels, num_groups, group_size): + if group_size: + assert num_channels % group_size == 0 + return num_channels // group_size + return num_groups + + +class GroupNormAct(nn.GroupNorm): + # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args + def __init__( + self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(GroupNormAct, self).__init__( + _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + self._fast_norm = is_fast_norm() + + def forward(self, x): + if self._fast_norm: + x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + x = self.drop(x) + x = self.act(x) + return x + + +class LayerNormAct(nn.LayerNorm): + def __init__( + self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + self._fast_norm = is_fast_norm() + + def forward(self, x): + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = self.drop(x) + x = self.act(x) + return x + + +class LayerNormAct2d(nn.LayerNorm): + def __init__( + self, num_channels, eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + self._fast_norm = is_fast_norm() + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) + x = self.drop(x) + x = self.act(x) + return x diff --git a/lib/models_timm/layers/padding.py b/lib/models_timm/layers/padding.py new file mode 100644 index 0000000..34afc37 --- /dev/null +++ b/lib/models_timm/layers/padding.py @@ -0,0 +1,56 @@ +""" Padding Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +from typing import List, Tuple + +import torch.nn.functional as F + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? +def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) + return x + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic diff --git a/lib/models_timm/layers/patch_embed.py b/lib/models_timm/layers/patch_embed.py new file mode 100644 index 0000000..be8740c --- /dev/null +++ b/lib/models_timm/layers/patch_embed.py @@ -0,0 +1,48 @@ +""" Image to Patch Embedding using Conv2d + +A convolution based approach to patchifying a 2D image w/ embedding projection. + +Based on the impl in https://github.com/google-research/vision_transformer + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .helpers import to_2tuple +from .trace_utils import _assert + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x diff --git a/lib/models_timm/layers/pool2d_same.py b/lib/models_timm/layers/pool2d_same.py new file mode 100644 index 0000000..4c2a1c4 --- /dev/null +++ b/lib/models_timm/layers/pool2d_same.py @@ -0,0 +1,73 @@ +""" AvgPool2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple, Optional + +from .helpers import to_2tuple +from .padding import pad_same, get_padding_value + + +def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + ceil_mode: bool = False, count_include_pad: bool = True): + # FIXME how to deal with count_include_pad vs not for external padding? + x = pad_same(x, kernel_size, stride) + return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + +class AvgPool2dSame(nn.AvgPool2d): + """ Tensorflow like 'SAME' wrapper for 2D average pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride) + return F.avg_pool2d( + x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) + + +def max_pool2d_same( + x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + dilation: List[int] = (1, 1), ceil_mode: bool = False): + x = pad_same(x, kernel_size, stride, value=-float('inf')) + return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) + + +class MaxPool2dSame(nn.MaxPool2d): + """ Tensorflow like 'SAME' wrapper for 2D max pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) + return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) + + +def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): + stride = stride or kernel_size + padding = kwargs.pop('padding', '') + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) + if is_dynamic: + if pool_type == 'avg': + return AvgPool2dSame(kernel_size, stride=stride, **kwargs) + elif pool_type == 'max': + return MaxPool2dSame(kernel_size, stride=stride, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' + else: + if pool_type == 'avg': + return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + elif pool_type == 'max': + return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' diff --git a/lib/models_timm/layers/pos_embed.py b/lib/models_timm/layers/pos_embed.py new file mode 100644 index 0000000..99a122a --- /dev/null +++ b/lib/models_timm/layers/pos_embed.py @@ -0,0 +1,207 @@ +import math +from typing import List, Tuple, Optional, Union + +import torch +from torch import nn as nn + + +def pixel_freq_bands( + num_bands: int, + max_freq: float = 224., + linear_bands: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + if linear_bands: + bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device) + else: + bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device) + return bands * torch.pi + + +def inv_freq_bands( + num_bands: int, + temperature: float = 100000., + step: int = 2, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) + return inv_freq + + +def build_sincos2d_pos_embed( + feat_shape: List[int], + dim: int = 64, + temperature: float = 10000., + reverse_coord: bool = False, + interleave_sin_cos: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None +) -> torch.Tensor: + """ + + Args: + feat_shape: + dim: + temperature: + reverse_coord: stack grid order W, H instead of H, W + interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos + dtype: + device: + + Returns: + + """ + assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' + pos_dim = dim // 4 + bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) + + if reverse_coord: + feat_shape = feat_shape[::-1] # stack W, H instead of H, W + grid = torch.stack( + torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) + pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) + # FIXME add support for unflattened spatial dim? + + stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos + pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) + return pos_emb + + +def build_fourier_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + num_bands: int = 64, + max_res: int = 224, + linear_bands: bool = False, + include_grid: bool = False, + concat_out: bool = True, + in_pixels: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + if bands is None: + if in_pixels: + bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device) + else: + bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device) + else: + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype + + if in_pixels: + grid = torch.stack(torch.meshgrid( + [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + else: + grid = torch.stack(torch.meshgrid( + [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + grid = grid.unsqueeze(-1) + pos = grid * bands + + pos_sin, pos_cos = pos.sin(), pos.cos() + out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos) + # FIXME torchscript doesn't like multiple return types, probably need to always cat? + if concat_out: + out = torch.cat(out, dim=-1) + return out + + +class FourierEmbed(nn.Module): + + def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False): + super().__init__() + self.max_res = max_res + self.num_bands = num_bands + self.concat_grid = concat_grid + self.keep_spatial = keep_spatial + self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False) + + def forward(self, x): + B, C = x.shape[:2] + feat_shape = x.shape[2:] + emb = build_fourier_pos_embed( + feat_shape, + self.bands, + include_grid=self.concat_grid, + dtype=x.dtype, + device=x.device) + emb = emb.transpose(-1, -2).flatten(len(feat_shape)) + batch_expand = (B,) + (-1,) * (x.ndim - 1) + + # FIXME support nD + if self.keep_spatial: + x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1) + else: + x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1) + x = x.reshape(B, feat_shape.numel(), -1) + + return x + + +def rot(x): + return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) + + +def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): + return x * cos_emb + rot(x) * sin_emb + + +def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): + if isinstance(x, torch.Tensor): + x = [x] + return [t * cos_emb + rot(t) * sin_emb for t in x] + + +def apply_rot_embed_split(x: torch.Tensor, emb): + split = emb.shape[-1] // 2 + return x * emb[:, :split] + rot(x) * emb[:, split:] + + +def build_rotary_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + dim: int = 64, + max_freq: float = 224, + linear_bands: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + """ + NOTE: shape arg should include spatial dim only + """ + feat_shape = torch.Size(feat_shape) + + sin_emb, cos_emb = build_fourier_pos_embed( + feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands, + concat_out=False, device=device, dtype=dtype) + N = feat_shape.numel() + sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1) + cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1) + return sin_emb, cos_emb + + +class RotaryEmbedding(nn.Module): + """ Rotary position embedding + + NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not + been well tested, and will likely change. It will be moved to its own file. + + The following impl/resources were referenced for this impl: + * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py + * https://blog.eleuther.ai/rotary-embeddings/ + """ + def __init__(self, dim, max_res=224, linear_bands: bool = False): + super().__init__() + self.dim = dim + self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False) + + def get_embed(self, shape: List[int]): + return build_rotary_pos_embed(shape, self.bands) + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + sin_emb, cos_emb = self.get_embed(x.shape[2:]) + return apply_rot_embed(x, sin_emb, cos_emb) diff --git a/lib/models_timm/layers/selective_kernel.py b/lib/models_timm/layers/selective_kernel.py new file mode 100644 index 0000000..3d71e3a --- /dev/null +++ b/lib/models_timm/layers/selective_kernel.py @@ -0,0 +1,119 @@ +""" Selective Kernel Convolution/Attention + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn + +from .conv_bn_act import ConvNormActAa +from .helpers import make_divisible +from .trace_utils import _assert + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + """ Selective Kernel Attention Module + + Selective Kernel attention mechanism factored out into its own module. + + """ + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + _assert(x.shape[1] == self.num_paths, '') + x = x.sum(1).mean((2, 3), keepdim=True) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernel(nn.Module): + + def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, + rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None): + """ Selective Kernel Convolution Module + + As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. + + Largest change is the input split, which divides the input channels across each convolution path, this can + be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps + the parameter count from ballooning when the convolutions themselves don't have groups, but still provides + a noteworthy increase in performance over similar param count models without this attention layer. -Ross W + + Args: + in_channels (int): module input (feature) channel count + out_channels (int): module output (feature) channel count + kernel_size (int, list): kernel size for each convolution branch + stride (int): stride for convolutions + dilation (int): dilation for module as a whole, impacts dilation of each branch + groups (int): number of groups for each branch + rd_ratio (int, float): reduction factor for attention features + keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations + split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, + can be viewed as grouping by path, output expands to module out_channels count + act_layer (nn.Module): activation layer to use + norm_layer (nn.Module): batchnorm/norm layer to use + aa_layer (nn.Module): anti-aliasing module + drop_layer (nn.Module): spatial drop module in convs (drop block, etc) + """ + super(SelectiveKernel, self).__init__() + out_channels = out_channels or in_channels + kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, drop_layer=drop_layer) + self.paths = nn.ModuleList([ + ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x diff --git a/lib/models_timm/layers/separable_conv.py b/lib/models_timm/layers/separable_conv.py new file mode 100644 index 0000000..c081e02 --- /dev/null +++ b/lib/models_timm/layers/separable_conv.py @@ -0,0 +1,76 @@ +""" Depthwise Separable Conv Modules + +Basic DWS convs. Other variations of DWS exist with batch norm or activations between the +DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import get_norm_act_layer + + +class SeparableConvNormAct(nn.Module): + """ Separable Conv w/ trailing Norm and Activation + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, + apply_act=True, drop_layer=None): + super(SeparableConvNormAct, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + x = self.bn(x) + return x + + +SeparableConvBnAct = SeparableConvNormAct + + +class SeparableConv2d(nn.Module): + """ Separable Conv + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1): + super(SeparableConv2d, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + return x diff --git a/lib/models_timm/layers/space_to_depth.py b/lib/models_timm/layers/space_to_depth.py new file mode 100644 index 0000000..a7e8e0b --- /dev/null +++ b/lib/models_timm/layers/space_to_depth.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + + +class SpaceToDepth(nn.Module): + def __init__(self, block_size=4): + super().__init__() + assert block_size == 4 + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) + return x + + +@torch.jit.script +class SpaceToDepthJit(object): + def __call__(self, x: torch.Tensor): + # assuming hard-coded that block_size==4 for acceleration + N, C, H, W = x.size() + x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) + return x + + +class SpaceToDepthModule(nn.Module): + def __init__(self, no_jit=False): + super().__init__() + if not no_jit: + self.op = SpaceToDepthJit() + else: + self.op = SpaceToDepth() + + def forward(self, x): + return self.op(x) + + +class DepthToSpace(nn.Module): + + def __init__(self, block_size): + super().__init__() + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) + x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) + return x diff --git a/lib/models_timm/layers/split_attn.py b/lib/models_timm/layers/split_attn.py new file mode 100644 index 0000000..ac54f89 --- /dev/null +++ b/lib/models_timm/layers/split_attn.py @@ -0,0 +1,84 @@ +""" Split Attention Conv2d (for ResNeSt Models) + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, performance, and consistency with timm by Ross Wightman +""" +import torch +import torch.nn.functional as F +from torch import nn + +from .helpers import make_divisible + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttn(nn.Module): + """Split-Attention (aka Splat) + """ + def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, + dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): + super(SplitAttn, self).__init__() + out_channels = out_channels or in_channels + self.radix = radix + mid_chs = out_channels * radix + if rd_channels is None: + attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) + else: + attn_chs = rd_channels * radix + + padding = kernel_size // 2 if padding is None else padding + self.conv = nn.Conv2d( + in_channels, mid_chs, kernel_size, stride, padding, dilation, + groups=groups * radix, bias=bias, **kwargs) + self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) + self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) + self.rsoftmax = RadixSoftmax(radix, groups) + + def forward(self, x): + x = self.conv(x) + x = self.bn0(x) + x = self.drop(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = x.sum(dim=1) + else: + x_gap = x + x_gap = x_gap.mean((2, 3), keepdim=True) + x_gap = self.fc1(x_gap) + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + x_attn = self.fc2(x_gap) + + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/lib/models_timm/layers/split_batchnorm.py b/lib/models_timm/layers/split_batchnorm.py new file mode 100644 index 0000000..830781b --- /dev/null +++ b/lib/models_timm/layers/split_batchnorm.py @@ -0,0 +1,75 @@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. + +This allows easily removing the auxiliary BN layers after training to efficiently +achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, +'Disentangled Learning via An Auxiliary BN' + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn + + +class SplitBatchNorm2d(torch.nn.BatchNorm2d): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, num_splits=2): + super().__init__(num_features, eps, momentum, affine, track_running_stats) + assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' + self.num_splits = num_splits + self.aux_bn = nn.ModuleList([ + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + + def forward(self, input: torch.Tensor): + if self.training: # aux BN only relevant while training + split_size = input.shape[0] // self.num_splits + assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" + split_input = input.split(split_size) + x = [super().forward(split_input[0])] + for i, a in enumerate(self.aux_bn): + x.append(a(split_input[i + 1])) + return torch.cat(x, dim=0) + else: + return super().forward(input) + + +def convert_splitbn_model(module, num_splits=2): + """ + Recursively traverse module and its children to replace all instances of + ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. + Args: + module (torch.nn.Module): input module + num_splits: number of separate batchnorm layers to split input across + Example:: + >>> # model is an instance of torch.nn.Module + >>> model = timm.models.convert_splitbn_model(model, num_splits=2) + """ + mod = module + if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): + return module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + mod = SplitBatchNorm2d( + module.num_features, module.eps, module.momentum, module.affine, + module.track_running_stats, num_splits=num_splits) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + mod.num_batches_tracked = module.num_batches_tracked + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + for aux in mod.aux_bn: + aux.running_mean = module.running_mean.clone() + aux.running_var = module.running_var.clone() + aux.num_batches_tracked = module.num_batches_tracked.clone() + if module.affine: + aux.weight.data = module.weight.data.clone().detach() + aux.bias.data = module.bias.data.clone().detach() + for name, child in module.named_children(): + mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) + del module + return mod diff --git a/lib/models_timm/layers/squeeze_excite.py b/lib/models_timm/layers/squeeze_excite.py new file mode 100644 index 0000000..4fe568f --- /dev/null +++ b/lib/models_timm/layers/squeeze_excite.py @@ -0,0 +1,102 @@ +""" Squeeze-and-Excitation Channel Attention + +An SE implementation originally based on PyTorch SE-Net impl. +Has since evolved with additional functionality / configuration. + +Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 + +Also included is Effective Squeeze-Excitation (ESE). +Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class SEModule(nn.Module): + """ SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): + super(SEModule, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = create_act_layer(act_layer, inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +SqueezeExcite = SEModule # alias + + +class EffectiveSEModule(nn.Module): + """ 'Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): + super(EffectiveSEModule, self).__init__() + self.add_maxpool = add_maxpool + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc(x_se) + return x * self.gate(x_se) + + +EffectiveSqueezeExcite = EffectiveSEModule # alias + + +class SqueezeExciteCl(nn.Module): + """ SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, + bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'): + super().__init__() + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Linear(channels, rd_channels, bias=bias) + self.act = create_act_layer(act_layer, inplace=True) + self.fc2 = nn.Linear(rd_channels, channels, bias=bias) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((1, 2), keepdims=True) # FIXME avg dim [1:n-1], don't assume 2D NHWC + x_se = self.fc1(x_se) + x_se = self.act(x_se) + x_se = self.fc2(x_se) + return x * self.gate(x_se) \ No newline at end of file diff --git a/lib/models_timm/layers/std_conv.py b/lib/models_timm/layers/std_conv.py new file mode 100644 index 0000000..d896ba5 --- /dev/null +++ b/lib/models_timm/layers/std_conv.py @@ -0,0 +1,133 @@ +""" Convolution with Weight Standardization (StdConv and ScaledStdConv) + +StdConv: +@article{weightstandardization, + author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille}, + title = {Weight Standardization}, + journal = {arXiv preprint arXiv:1903.10520}, + year = {2019}, +} +Code: https://github.com/joe-siyuan-qiao/WeightStandardization + +ScaledStdConv: +Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Hacked together by / copyright Ross Wightman, 2021. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .padding import get_padding, get_padding_value, pad_same + + +class StdConv2d(nn.Conv2d): + """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=False, eps=1e-6): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class StdConv2dSame(nn.Conv2d): + """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=False, eps=1e-6): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class ScaledStdConv2d(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization. + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ScaledStdConv2dSame(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) diff --git a/lib/models_timm/layers/test_time_pool.py b/lib/models_timm/layers/test_time_pool.py new file mode 100644 index 0000000..5826d8c --- /dev/null +++ b/lib/models_timm/layers/test_time_pool.py @@ -0,0 +1,52 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +from torch import nn +import torch.nn.functional as F + +from .adaptive_avgmax_pool import adaptive_avgmax_pool2d + + +_logger = logging.getLogger(__name__) + + +class TestTimePoolHead(nn.Module): + def __init__(self, base, original_pool=7): + super(TestTimePoolHead, self).__init__() + self.base = base + self.original_pool = original_pool + base_fc = self.base.get_classifier() + if isinstance(base_fc, nn.Conv2d): + self.fc = base_fc + else: + self.fc = nn.Conv2d( + self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) + self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) + self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) + self.base.reset_classifier(0) # delete original fc layer + + def forward(self, x): + x = self.base.forward_features(x) + x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) + x = self.fc(x) + x = adaptive_avgmax_pool2d(x, 1) + return x.view(x.size(0), -1) + + +def apply_test_time_pool(model, config, use_test_size=False): + test_time_pool = False + if not hasattr(model, 'default_cfg') or not model.default_cfg: + return model, False + if use_test_size and 'test_input_size' in model.default_cfg: + df_input_size = model.default_cfg['test_input_size'] + else: + df_input_size = model.default_cfg['input_size'] + if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: + _logger.info('Target input size %s > pretrained default %s, using test time pooling' % + (str(config['input_size'][-2:]), str(df_input_size[-2:]))) + model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) + test_time_pool = True + return model, test_time_pool diff --git a/lib/models_timm/layers/trace_utils.py b/lib/models_timm/layers/trace_utils.py new file mode 100644 index 0000000..8397072 --- /dev/null +++ b/lib/models_timm/layers/trace_utils.py @@ -0,0 +1,13 @@ +try: + from torch import _assert +except ImportError: + def _assert(condition: bool, message: str): + assert condition, message + + +def _float_to_int(x: float) -> int: + """ + Symbolic tracing helper to substitute for inbuilt `int`. + Hint: Inbuilt `int` can't accept an argument of type `Proxy` + """ + return int(x) diff --git a/lib/models_timm/layers/weight_init.py b/lib/models_timm/layers/weight_init.py new file mode 100644 index 0000000..943e4f4 --- /dev/null +++ b/lib/models_timm/layers/weight_init.py @@ -0,0 +1,125 @@ +import torch +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) + + +def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + return tensor + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')