
Merge pull request huggingface#1789 from huggingface/mw-final
Final push to get remaining models using multi-weight pretrained configs and HF hub weights
rwightman authored Apr 27, 2023
2 parents 9ee846f + 493c730 commit bd5f9a3
Showing 37 changed files with 2,404 additions and 2,252 deletions.
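
The multi-weight pretrained configs referenced in the commit message let a single architecture expose several weight tags, with files resolved from the Hugging Face hub. A rough usage sketch (the model and tag names below are examples, not taken from this commit):

```python
# Sketch: selecting a specific pretrained weight tag (names are illustrative).
import timm

# A weight tag is appended to the architecture name as `model.tag`.
model = timm.create_model('resnet50.a1_in1k', pretrained=True)

# Omitting the tag falls back to the architecture's default pretrained config.
default = timm.create_model('resnet50', pretrained=True)
```
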
12 changes: 4 additions & 8 deletions tests/test_models.py
@@ -58,6 +58,8 @@
EXCLUDE_FILTERS = ['*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']

EXCLUDE_JIT_FILTERS = []

TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
MAX_BWD_SIZE = 320
@@ -277,7 +279,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):


if 'GITHUB_ACTIONS' not in os.environ:
@pytest.mark.timeout(120)
@pytest.mark.timeout(240)
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_load_pretrained(model_name, batch_size):
@@ -286,19 +288,13 @@ def test_model_load_pretrained(model_name, batch_size):
create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0)

@pytest.mark.timeout(120)
@pytest.mark.timeout(240)
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_features_pretrained(model_name, batch_size):
"""Create that pretrained weights load when features_only==True."""
create_model(model_name, pretrained=True, features_only=True)

EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', 'ghostnet*',  # hopefully fix at some point
'vit_large_*', 'vit_huge_*', 'vit_gi*',
]


@pytest.mark.torchscript
@pytest.mark.timeout(120)
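
For orientation, a minimal sketch of what one parametrized case of `test_model_load_pretrained` above reduces to; `list_models` and `create_model` are the timm entry points the tests use, while the index and class count here are arbitrary:

```python
# Sketch of a single pretrained-load test case (index and num_classes are arbitrary).
import timm

names = timm.list_models(pretrained=True)  # every model with hosted weights
model = timm.create_model(names[0], pretrained=True, num_classes=5)
```
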
4 changes: 4 additions & 0 deletions timm/layers/classifier.py
@@ -52,6 +52,7 @@ def create_classifier(
pool_type: str = 'avg',
use_conv: bool = False,
input_fmt: str = 'NCHW',
drop_rate: Optional[float] = None,
):
global_pool, num_pooled_features = _create_pool(
num_features,
@@ -65,6 +66,9 @@
num_classes,
use_conv=use_conv,
)
if drop_rate is not None:
dropout = nn.Dropout(drop_rate)
return global_pool, dropout, fc
return global_pool, fc


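A usage sketch of the optional `drop_rate` path added to `create_classifier` above; the feature and class counts are assumptions, not values from the diff:

```python
# Sketch: create_classifier with and without the new drop_rate argument.
from timm.layers import create_classifier

# drop_rate=None keeps the original two-value return.
global_pool, fc = create_classifier(2048, 1000, pool_type='avg')

# With drop_rate set, an nn.Dropout is constructed and returned as well.
global_pool, dropout, fc = create_classifier(2048, 1000, pool_type='avg', drop_rate=0.2)
```
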
57 changes: 49 additions & 8 deletions timm/layers/conv_bn_act.py
@@ -11,18 +11,41 @@

class ConvNormAct(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None):
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding='',
dilation=1,
groups=1,
bias=False,
apply_act=True,
norm_layer=nn.BatchNorm2d,
norm_kwargs=None,
act_layer=nn.ReLU,
act_kwargs=None,
drop_layer=None,
):
super(ConvNormAct, self).__init__()
norm_kwargs = norm_kwargs or {}
act_kwargs = act_kwargs or {}

self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)

# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn = norm_act_layer(
out_channels,
apply_act=apply_act,
act_kwargs=act_kwargs,
**norm_kwargs,
)

@property
def in_channels(self):
@@ -57,10 +80,27 @@ def create_aa(aa_layer, channels, stride=2, enable=True):

class ConvNormActAa(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding='',
dilation=1,
groups=1,
bias=False,
apply_act=True,
norm_layer=nn.BatchNorm2d,
norm_kwargs=None,
act_layer=nn.ReLU,
act_kwargs=None,
aa_layer=None,
drop_layer=None,
):
super(ConvNormActAa, self).__init__()
use_aa = aa_layer is not None and stride == 2
norm_kwargs = norm_kwargs or {}
act_kwargs = act_kwargs or {}

self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
@@ -69,8 +109,9 @@ def __init__(
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)

@property
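
To show how the widened constructors above are meant to be used, a sketch passing the new `norm_kwargs`/`act_kwargs` dicts through `ConvNormAct`; the `eps` and `negative_slope` values are illustrative assumptions:

```python
# Sketch: forwarding norm/act keyword overrides through ConvNormAct.
import torch
import torch.nn as nn
from timm.layers import ConvNormAct

block = ConvNormAct(
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    stride=2,
    norm_layer=nn.BatchNorm2d,
    norm_kwargs=dict(eps=1e-3),           # forwarded to the fused norm-act layer (`.bn`)
    act_layer=nn.LeakyReLU,
    act_kwargs=dict(negative_slope=0.1),  # replaces the older act_params mechanism
)
out = block(torch.randn(2, 32, 56, 56))   # -> torch.Size([2, 64, 28, 28])
```
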
82 changes: 40 additions & 42 deletions timm/layers/norm_act.py
@@ -24,6 +24,18 @@
from .trace_utils import _assert


def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
act_layer = get_act_layer(act_layer) # string -> nn.Module
act_kwargs = act_kwargs or {}
if act_layer is not None and apply_act:
if inplace:
act_kwargs['inplace'] = inplace
act = act_layer(**act_kwargs)
else:
act = nn.Identity()
return act


class BatchNormAct2d(nn.BatchNorm2d):
"""BatchNorm + Activation
@@ -40,31 +52,33 @@ def __init__(
track_running_stats=True,
apply_act=True,
act_layer=nn.ReLU,
act_params=None, # FIXME not the final approach
act_kwargs=None,
inplace=True,
drop_layer=None,
device=None,
dtype=None
dtype=None,
):
try:
factory_kwargs = {'device': device, 'dtype': dtype}
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
**factory_kwargs
num_features,
eps=eps,
momentum=momentum,
affine=affine,
track_running_stats=track_running_stats,
**factory_kwargs,
)
except TypeError:
# NOTE for backwards compat with old PyTorch w/o factory device/dtype support
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
num_features,
eps=eps,
momentum=momentum,
affine=affine,
track_running_stats=track_running_stats,
)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
if act_params is not None:
act_args['negative_slope'] = act_params
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

def forward(self, x):
# cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
@@ -188,6 +202,7 @@ def __init__(
eps: float = 1e-5,
apply_act=True,
act_layer=nn.ReLU,
act_kwargs=None,
inplace=True,
drop_layer=None,
):
@@ -199,12 +214,7 @@ def __init__(
self.register_buffer("running_var", torch.ones(num_features))

self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

def _load_from_state_dict(
self,
@@ -344,6 +354,7 @@ def __init__(
group_size=None,
apply_act=True,
act_layer=nn.ReLU,
act_kwargs=None,
inplace=True,
drop_layer=None,
):
@@ -354,12 +365,8 @@ def __init__(
affine=affine,
)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

self._fast_norm = is_fast_norm()

def forward(self, x):
Expand All @@ -380,17 +387,14 @@ def __init__(
affine=True,
apply_act=True,
act_layer=nn.ReLU,
act_kwargs=None,
inplace=True,
drop_layer=None,
):
super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

self._fast_norm = is_fast_norm()

def forward(self, x):
@@ -411,17 +415,15 @@ def __init__(
affine=True,
apply_act=True,
act_layer=nn.ReLU,
act_kwargs=None,
inplace=True,
drop_layer=None,
):
super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

self._fast_norm = is_fast_norm()

def forward(self, x):
Expand All @@ -442,17 +444,13 @@ def __init__(
affine=True,
apply_act=True,
act_layer=nn.ReLU,
act_kwargs=None,
inplace=True,
drop_layer=None,
):
super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
self._fast_norm = is_fast_norm()

def forward(self, x):
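
A short sketch of the consolidated activation handling above: the fused norm-act layers now accept `act_kwargs`, which `_create_act` merges with the `inplace` flag (the `negative_slope` value is an assumption):

```python
# Sketch: act_kwargs on a fused norm-act layer (replaces the old act_params argument).
import torch
import torch.nn as nn
from timm.layers import BatchNormAct2d

norm_act = BatchNormAct2d(
    64,
    act_layer=nn.LeakyReLU,
    act_kwargs=dict(negative_slope=0.2),  # merged with inplace=True inside _create_act
)
y = norm_act(torch.randn(2, 64, 8, 8))    # BN -> (optional drop) -> LeakyReLU
```
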
24 changes: 15 additions & 9 deletions timm/layers/patch_embed.py
@@ -29,7 +29,7 @@ class PatchEmbed(nn.Module):

def __init__(
self,
img_size: int = 224,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
@@ -39,12 +39,16 @@ def __init__(
bias: bool = True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.patch_size = to_2tuple(patch_size)
if img_size is not None:
self.img_size = to_2tuple(img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None

if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
@@ -58,8 +62,10 @@ def __init__(

def forward(self, x):
B, C, H, W = x.shape
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
if self.img_size is not None:
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")

x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
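
A sketch of the `img_size=None` path added to `PatchEmbed` above: without a fixed image size the input assertions are skipped, so one module handles multiple patch-divisible resolutions; the shapes below are illustrative:

```python
# Sketch: PatchEmbed with img_size=None accepts variable input resolutions.
import torch
from timm.layers import PatchEmbed

embed = PatchEmbed(img_size=None, patch_size=16, in_chans=3, embed_dim=768)
tokens_224 = embed(torch.randn(1, 3, 224, 224))  # (1, 196, 768)
tokens_320 = embed(torch.randn(1, 3, 320, 320))  # (1, 400, 768)
```
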
4 changes: 3 additions & 1 deletion timm/layers/space_to_depth.py
@@ -3,6 +3,8 @@


class SpaceToDepth(nn.Module):
bs: torch.jit.Final[int]

def __init__(self, block_size=4):
super().__init__()
assert block_size == 4
@@ -12,7 +14,7 @@ def forward(self, x):
N, C, H, W = x.size()
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
return x


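A quick check of the scripting-oriented tweaks above (the `torch.jit.Final` annotation and the pow-free reshape); the input shape is arbitrary:

```python
# Sketch: SpaceToDepth remains TorchScript-compatible after the changes above.
import torch
from timm.layers.space_to_depth import SpaceToDepth

scripted = torch.jit.script(SpaceToDepth(block_size=4))
out = scripted(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 48, 8, 8])
```
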
1 change: 0 additions & 1 deletion timm/models/__init__.py
@@ -21,7 +21,6 @@
from .focalnet import *
from .gcvit import *
from .ghostnet import *
from .gluon_xception import *
from .hardcorenas import *
from .hrnet import *
from .inception_resnet_v2 import *
(Diffs for the remaining changed files in this commit are not expanded in this view.)