From 850ab4931fd529899e5f82650825ee3028ff6879 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 10 May 2023 12:16:30 -0700 Subject: [PATCH] Missed a few pretrained tags... --- timm/models/dla.py | 2 +- timm/models/gcvit.py | 52 +++++++++++++++++++++---------------------- timm/models/pvt_v2.py | 14 ++++++------ 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/timm/models/dla.py b/timm/models/dla.py index e7c20dca03..3052819db7 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -415,7 +415,7 @@ def _cfg(url='', **kwargs): 'dla102.in1k': _cfg(hf_hub_id='timm/'), 'dla102x.in1k': _cfg(hf_hub_id='timm/'), 'dla102x2.in1k': _cfg(hf_hub_id='timm/'), - 'dla169': _cfg(hf_hub_id='timm/'), + 'dla169.in1k': _cfg(hf_hub_id='timm/'), 'dla60_res2net.in1k': _cfg(hf_hub_id='timm/'), 'dla60_res2next.in1k': _cfg(hf_hub_id='timm/'), }) diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 29cdb1fbe0..29536a7dd2 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -33,36 +33,11 @@ from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import named_apply -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['GlobalContextVit'] -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.conv1', 'classifier': 'head.fc', - 'fixed_input_size': True, - **kwargs - } - - -default_cfgs = { - 'gcvit_xxtiny': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'), - 'gcvit_xtiny': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'), - 'gcvit_tiny': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'), - 'gcvit_small': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'), - 'gcvit_base': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'), -} - - class MbConvBlock(nn.Module): """ A depthwise separable / fused mbconv style residual block with SE, `no norm. """ @@ -541,6 +516,31 @@ def _create_gcvit(variant, pretrained=False, **kwargs): return model +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'gcvit_xxtiny.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'), + 'gcvit_xtiny.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'), + 'gcvit_tiny.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'), + 'gcvit_small.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'), + 'gcvit_base.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'), +}) + + @register_model def gcvit_xxtiny(pretrained=False, **kwargs) -> GlobalContextVit: model_kwargs = dict( diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index e594e6d50f..00379b158a 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -448,13 +448,13 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ - 'pvt_v2_b0': _cfg(hf_hub_id='timm/'), - 'pvt_v2_b1': _cfg(hf_hub_id='timm/'), - 'pvt_v2_b2': _cfg(hf_hub_id='timm/'), - 'pvt_v2_b3': _cfg(hf_hub_id='timm/'), - 'pvt_v2_b4': _cfg(hf_hub_id='timm/'), - 'pvt_v2_b5': _cfg(hf_hub_id='timm/'), - 'pvt_v2_b2_li': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b0.in1k': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b1.in1k': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b2.in1k': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b3.in1k': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b4.in1k': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b5.in1k': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b2_li.in1k': _cfg(hf_hub_id='timm/'), })