Merge pull request huggingface#1799 from huggingface/dot_nine_cleanup
Final cleanup before .9 release
rwightman authored May 10, 2023
2 parents 5cc87e6 + b9d43c7 commit c9db470
Showing 73 changed files with 1,453 additions and 1,236 deletions.
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
@@ -34,7 +34,7 @@
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout
from .patch_embed import PatchEmbed, resample_patch_embed
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
45 changes: 44 additions & 1 deletion timm/layers/patch_embed.py
@@ -9,7 +9,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List, Optional, Callable
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn as nn
@@ -75,6 +75,49 @@ def forward(self, x):
return x


class PatchEmbedWithSize(PatchEmbed):
""" 2D Image to Patch Embedding
"""
output_fmt: Format

def __init__(
self,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
):
super().__init__(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer,
flatten=flatten,
output_fmt=output_fmt,
bias=bias,
)

def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
B, C, H, W = x.shape
if self.img_size is not None:
_assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
_assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")

x = self.proj(x)
grid_size = x.shape[-2:]
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x, grid_size


def resample_patch_embed(
patch_embed,
new_size: List[int],
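For reference, a minimal usage sketch of the new PatchEmbedWithSize layer (not part of the diff; assumes a timm build containing this commit):

import torch
from timm.layers import PatchEmbedWithSize

# 224x224 input with 16x16 patches -> a 14x14 grid of 768-dim patch tokens
embed = PatchEmbedWithSize(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens, grid_size = embed(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 196, 768]) -- NLC output with the default flatten=True
print(grid_size)     # torch.Size([14, 14]) -- feature grid (H', W') after the projection

Unlike the base PatchEmbed, the forward pass returns the post-projection grid size alongside the tokens, which is handy when the input resolution (and therefore the token count) varies.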
24 changes: 13 additions & 11 deletions timm/layers/pos_embed.py
@@ -24,29 +24,31 @@ def resample_abs_pos_embed(
verbose: bool = False,
):
# sort out sizes, assume square if old size not provided
new_size = to_2tuple(new_size)
new_ntok = new_size[0] * new_size[1]
if not old_size:
old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
old_size = to_2tuple(old_size)
if new_size == old_size: # might not both be same container type
num_pos_tokens = posemb.shape[1]
num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
return posemb

if not old_size:
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
old_size = hw, hw

if num_prefix_tokens:
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
else:
posemb_prefix, posemb = None, posemb

# do the interpolation
embed_dim = posemb.shape[-1]
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)

if verbose:
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)

# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
posemb = torch.cat([posemb_prefix, posemb], dim=1)

if not torch.jit.is_scripting() and verbose:
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')

return posemb
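
A hedged sketch of the resampling path changed above; the keyword names used below (new_size, num_prefix_tokens) appear in the hunk, while their default values are assumed:

import torch
from timm.layers import resample_abs_pos_embed

# 1 class token plus a 14x14 grid of 768-dim absolute position embeddings
posemb = torch.randn(1, 1 + 14 * 14, 768)

# resample the grid portion to 24x24; the prefix (class) token is split off and re-attached
posemb_24 = resample_abs_pos_embed(posemb, new_size=(24, 24), num_prefix_tokens=1)
print(posemb_24.shape)  # torch.Size([1, 577, 768]) -- 1 prefix token + 24*24 grid tokens

Note that the early-exit condition now compares token counts and only short-circuits for square targets, so non-square new_size values are always resampled.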
27 changes: 19 additions & 8 deletions timm/models/_builder.py
@@ -22,7 +22,7 @@
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False

_USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0

__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
@@ -32,6 +32,7 @@ def _resolve_pretrained_source(pretrained_cfg):
cfg_source = pretrained_cfg.get('source', '')
pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None)
pretrained_sd = pretrained_cfg.get('state_dict', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)

# resolve where to load pretrained weights from
@@ -44,14 +45,21 @@ def _resolve_pretrained_source(pretrained_cfg):
pretrained_loc = hf_hub_id
else:
# default source == timm or unspecified
if pretrained_file:
# file load override is the highest priority if set
if pretrained_sd:
# direct state_dict pass through is the highest priority
load_from = 'state_dict'
pretrained_loc = pretrained_sd
assert isinstance(pretrained_loc, dict)
elif pretrained_file:
# file load override is the second-highest priority if set
load_from = 'file'
pretrained_loc = pretrained_file
else:
# next, HF hub is prioritized unless a valid cached version of weights exists already
cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
old_cache_valid = False
if _USE_OLD_CACHE:
# prioritize old cached weights if they exist and the env var is enabled
old_cache_valid = check_cached_file(pretrained_url) if pretrained_url else False
if not old_cache_valid and hf_hub_id and has_hf_hub(necessary=True):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
@@ -106,7 +114,7 @@ def load_custom_pretrained(
if not load_from:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
if load_from == 'hf-hub': # FIXME
if load_from == 'hf-hub':
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
elif load_from == 'url':
pretrained_loc = download_cached_file(
@@ -148,7 +156,10 @@ def load_pretrained(
return

load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
if load_from == 'file':
if load_from == 'state_dict':
_logger.info(f'Loading pretrained weights from state dict')
state_dict = pretrained_loc # pretrained_loc is the actual state dict for this override
elif load_from == 'file':
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
state_dict = load_state_dict(pretrained_loc)
elif load_from == 'url':
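A hedged sketch of the new in-memory state_dict source; the pretrained_cfg_overlay argument to timm.create_model is assumed to be how such a cfg override reaches _resolve_pretrained_source, and the checkpoint path is hypothetical:

import timm
import torch

sd = torch.load('my_finetuned_weights.pth', map_location='cpu')  # hypothetical local checkpoint
model = timm.create_model(
    'resnet50',
    pretrained=True,
    # 'state_dict' is the highest-priority weight source, ahead of 'file' and 'url' / 'hf-hub'
    pretrained_cfg_overlay=dict(state_dict=sd),
)

Separately, setting the TIMM_USE_OLD_CACHE=1 environment variable restores the old preference for an already-cached URL download over the Hugging Face Hub.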
11 changes: 6 additions & 5 deletions timm/models/_pretrained.py
@@ -11,11 +11,12 @@
class PretrainedCfg:
"""
"""
# weight locations
url: Optional[Union[str, Tuple[str, str]]] = None
file: Optional[str] = None
hf_hub_id: Optional[str] = None
hf_hub_filename: Optional[str] = None
# weight source locations
url: Optional[Union[str, Tuple[str, str]]] = None # remote URL
file: Optional[str] = None # local / shared filesystem path
state_dict: Optional[Dict[str, Any]] = None # in-memory state dict
hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model')
hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default)

source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
architecture: Optional[str] = None # architecture variant can be set when not implicit
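For illustration, a PretrainedCfg carrying in-memory weights could be built directly; this is a sketch and assumes PretrainedCfg is importable from timm.models:

import torch
from timm.models import PretrainedCfg

toy_sd = {'conv1.weight': torch.zeros(64, 3, 7, 7)}  # toy in-memory weights for illustration
cfg = PretrainedCfg(
    state_dict=toy_sd,        # new field: takes precedence over file, url and hf_hub_id
    architecture='resnet50',  # architecture variant (existing field)
)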
24 changes: 17 additions & 7 deletions timm/models/beit.py
@@ -477,6 +477,11 @@ def _cfg(url='', **kwargs):
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_base_patch16_224.in1k_ft_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft1k.pth',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
hf_hub_id='timm/',
@@ -487,6 +492,11 @@ def _cfg(url='', **kwargs):
hf_hub_id='timm/',
crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft1k.pth',
hf_hub_id='timm/',
crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
hf_hub_id='timm/',
@@ -515,7 +525,7 @@ def _create_beit(variant, pretrained=False, **kwargs):


@register_model
def beit_base_patch16_224(pretrained=False, **kwargs):
def beit_base_patch16_224(pretrained=False, **kwargs) -> Beit:
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
@@ -524,7 +534,7 @@ def beit_base_patch16_224(pretrained=False, **kwargs):


@register_model
def beit_base_patch16_384(pretrained=False, **kwargs):
def beit_base_patch16_384(pretrained=False, **kwargs) -> Beit:
model_args = dict(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
@@ -533,7 +543,7 @@ def beit_base_patch16_384(pretrained=False, **kwargs):


@register_model
def beit_large_patch16_224(pretrained=False, **kwargs):
def beit_large_patch16_224(pretrained=False, **kwargs) -> Beit:
model_args = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
@@ -542,7 +552,7 @@ def beit_large_patch16_224(pretrained=False, **kwargs):


@register_model
def beit_large_patch16_384(pretrained=False, **kwargs):
def beit_large_patch16_384(pretrained=False, **kwargs) -> Beit:
model_args = dict(
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
@@ -551,7 +561,7 @@ def beit_large_patch16_384(pretrained=False, **kwargs):


@register_model
def beit_large_patch16_512(pretrained=False, **kwargs):
def beit_large_patch16_512(pretrained=False, **kwargs) -> Beit:
model_args = dict(
img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
@@ -560,7 +570,7 @@ def beit_large_patch16_512(pretrained=False, **kwargs):


@register_model
def beitv2_base_patch16_224(pretrained=False, **kwargs):
def beitv2_base_patch16_224(pretrained=False, **kwargs) -> Beit:
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
@@ -569,7 +579,7 @@ def beitv2_base_patch16_224(pretrained=False, **kwargs):


@register_model
def beitv2_large_patch16_224(pretrained=False, **kwargs):
def beitv2_large_patch16_224(pretrained=False, **kwargs) -> Beit:
model_args = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
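The new in1k_ft_in1k pretrained tags added above can be selected directly by name; a quick sketch (weights come from the Hub id / URL configured in the cfg above):

import timm

# BEiT v2 base, pretrained and fine-tuned on ImageNet-1k only (new cfg in this commit)
model = timm.create_model('beitv2_base_patch16_224.in1k_ft_in1k', pretrained=True)
model = model.eval()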
(diffs for the remaining changed files not shown)
