Cleanup before samvit merge. Resize abs posembed on the fly, undo some line-wraps, remove redundant unbind, fix HF hub weight load
rwightman committed May 18, 2023
1 parent c1c6eeb commit e9373b1
Showing 4 changed files with 93 additions and 60 deletions.
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
@@ -36,7 +36,7 @@
from .patch_dropout import PatchDropout
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
16 changes: 14 additions & 2 deletions timm/layers/patch_embed.py
@@ -37,6 +37,7 @@ def __init__(
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
strict_img_size: bool = True,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
@@ -56,15 +57,26 @@ def __init__(
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x):
B, C, H, W = x.shape
if self.img_size is not None:
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
if self.strict_img_size:
_assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
else:
_assert(
H % self.patch_size[0] == 0,
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
)
_assert(
W % self.patch_size[1] == 0,
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
)

x = self.proj(x)
if self.flatten:
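
With strict_img_size=False, PatchEmbed only checks that the input is divisible by the patch size instead of requiring an exact match with the configured img_size. A minimal usage sketch, assuming this version of the layer (shapes are illustrative):

    import torch
    from timm.layers import PatchEmbed

    # img_size is still recorded, but inputs only need to be patch-divisible
    embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768, strict_img_size=False)
    x = torch.randn(1, 3, 192, 256)   # not 224x224, but both dims divisible by 16
    tokens = embed(x)                 # default flatten=True -> (B, N, C) tokens
    print(tokens.shape)               # torch.Size([1, 192, 768]); N = (192 // 16) * (256 // 16)
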
21 changes: 21 additions & 0 deletions timm/layers/pos_embed.py
@@ -52,3 +52,24 @@ def resample_abs_pos_embed(
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')

return posemb


def resample_abs_pos_embed_nhwc(
posemb,
new_size: List[int],
interpolation: str = 'bicubic',
antialias: bool = True,
verbose: bool = False,
):
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
return posemb

# do the interpolation
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1)

if not torch.jit.is_scripting() and verbose:
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')

return posemb
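
The new resample_abs_pos_embed_nhwc keeps the (1, H, W, C) layout used by the SAM ViT and is a no-op when the grid size already matches. A short usage sketch, assuming a timm build that includes this commit (sizes are illustrative):

    import torch
    from timm.layers import resample_abs_pos_embed_nhwc

    pos_embed = torch.zeros(1, 64, 64, 768)                    # 1024px / 16px patches -> 64x64 grid
    resized = resample_abs_pos_embed_nhwc(pos_embed, [32, 32])
    print(resized.shape)                                       # torch.Size([1, 32, 32, 768])
    same = resample_abs_pos_embed_nhwc(pos_embed, [64, 64])    # grid unchanged -> returned as-is
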
114 changes: 57 additions & 57 deletions timm/models/vision_transformer_sam.py
@@ -19,7 +19,8 @@
import torch.utils.checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, Format
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\
Format, resample_abs_pos_embed_nhwc
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -71,24 +72,21 @@ def __init__(

def forward(self, x):
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(
B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.unbind(0)
# qkv with shape (3, B, nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
# q, k, v with shape (B * nHead, H * W, C)
q, k = self.q_norm(q), self.k_norm(k)

attn = (q * self.scale) @ k.transpose(-2, -1)
q = q * self.scale
attn = q @ k.transpose(-2, -1)

if self.use_rel_pos:
attn = add_decomposed_rel_pos(
attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).view(B, self.num_heads, H, W, -
1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)

return x
@@ -136,13 +134,10 @@ def __init__(
proj_drop=proj_drop,
norm_layer=norm_layer,
use_rel_pos=use_rel_pos,
input_size=input_size if window_size == 0 else (
window_size, window_size),
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.ls1 = LayerScale(
dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

self.norm2 = norm_layer(dim)
self.mlp = mlp_layer(
@@ -151,10 +146,8 @@ def __init__(
act_layer=act_layer,
drop=proj_drop,
)
self.ls2 = LayerScale(
dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
shortcut = x
@@ -194,10 +187,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w

x = x.view(B, Hp // window_size, window_size,
Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
).view(-1, window_size, window_size, C)
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)


@@ -218,8 +209,7 @@ def window_unpartition(
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size,
window_size, window_size, -1)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

if Hp > H or Wp > W:
@@ -248,16 +238,14 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(
-1, max_rel_dist).permute(1, 0)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos

# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + \
(k_size - 1) * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

return rel_pos_resized[relative_coords.long()]

@@ -331,7 +319,7 @@ def __init__(
drop_path_rate: float = 0.,
weight_init: str = '',
embed_layer: Callable = partial(
PatchEmbed, output_fmt=Format.NHWC),
PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
norm_layer: Optional[Callable] = nn.LayerNorm,
act_layer: Optional[Callable] = nn.GELU,
block_fn: Callable = Block,
@@ -342,6 +330,7 @@ def __init__(
global_attn_indexes: Tuple[int, ...] = (),
neck_chans: int = 256,
global_pool: str = 'avg',
head_hidden_size: Optional[int] = None
):
"""
Args:
@@ -370,6 +359,7 @@ def __init__(
window_size: Window size for window attention blocks. If 0, not use window attention.
global_attn_indexes: Indexes for blocks using global attention. Used when window_size > 0.
global_pool: Global pooling type.
head_hidden_size: If set, use NormMlpClassifierHead
"""
super().__init__()
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
@@ -388,14 +378,12 @@ def __init__(
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used
)
grid_size = self.patch_embed.grid_size
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size,
img_size // patch_size, embed_dim)
)
self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim))
else:
self.pos_embed = 0.
self.pos_embed = None
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
@@ -424,7 +412,7 @@ def __init__(
mlp_layer=mlp_layer,
use_rel_pos=use_rel_pos,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
input_size=grid_size,
)
for i in range(depth)])

@@ -451,12 +439,21 @@ def __init__(
neck_chans = embed_dim

# Classifier Head
self.head = ClassifierHead(
neck_chans,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)
if head_hidden_size:
self.head = NormMlpClassifierHead(
neck_chans,
num_classes,
hidden_size=head_hidden_size,
pool_type=global_pool,
drop_rate=drop_rate,
)
else:
self.head = ClassifierHead(
neck_chans,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)

@torch.jit.ignore
def no_weight_decay(self):
@@ -478,15 +475,14 @@ def get_classifier(self):
return self.head

def reset_classifier(self, num_classes=0, global_pool=None):
self.head = self.head.reset(num_classes, global_pool) if num_classes > 0 else nn.Identity()

def _pos_embed(self, x):
x = x + self.pos_embed
return self.pos_drop(x)
self.head.reset(num_classes, global_pool)

def forward_features(self, x):
x = self.patch_embed(x)
x = self._pos_embed(x)
if self.pos_embed is not None:
# dynamically resize abs pos embedding if needed
x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
x = self.pos_drop(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -507,15 +503,19 @@ def forward(self, x):

def checkpoint_filter_fn(
state_dict,
model
model,
):
""" Remap SAM checkpoints -> timm """
sam_checkpoint = 'image_encoder.patch_embed.proj.weight' in state_dict
out_dict = {}
for k, v in state_dict.items():
if 'image_encoder.' in k:
new_k = k.replace('image_encoder.', '')
new_k = new_k.replace('mlp.lin', 'mlp.fc')
out_dict[new_k] = v
if k.startswith('image_encoder.'):
k = k[14:]
k = k.replace('mlp.lin', 'mlp.fc')
else:
if sam_checkpoint:
continue
out_dict[k] = v
return out_dict
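
For SAM-format checkpoints the filter now strips the image_encoder. prefix, renames mlp.lin* to mlp.fc*, and drops non-encoder weights, while timm-format (HF hub) checkpoints pass through untouched; that pass-through is the "fix HF hub weight load" part of this commit. A standalone mirror of that logic, for illustration only, with representative (not exhaustive) keys:

    # Minimal sketch mirroring checkpoint_filter_fn's key remap, illustrative only.
    def remap_sam_keys(state_dict):
        sam_checkpoint = 'image_encoder.patch_embed.proj.weight' in state_dict
        out = {}
        for k, v in state_dict.items():
            if k.startswith('image_encoder.'):
                k = k[len('image_encoder.'):].replace('mlp.lin', 'mlp.fc')
            elif sam_checkpoint:
                continue  # drop non-encoder weights from original SAM checkpoints
            out[k] = v
        return out

    sam_style = {
        'image_encoder.patch_embed.proj.weight': 0,
        'image_encoder.blocks.0.mlp.lin1.weight': 0,
        'mask_decoder.iou_token.weight': 0,   # dropped: not part of the image encoder
    }
    print(sorted(remap_sam_keys(sam_style)))  # ['blocks.0.mlp.fc1.weight', 'patch_embed.proj.weight']
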


@@ -535,19 +535,19 @@ def _cfg(url='', **kwargs):
# Segment-Anything Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
'samvit_base_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
# hf_hub_id='timm/',
hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
'samvit_large_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
# hf_hub_id='timm/',
hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
'samvit_huge_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
# hf_hub_id='timm/',
hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
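
Taken together, the non-strict patch embedding and the NHWC pos embed resample let the SAM ViT run at resolutions other than the 1024x1024 pretraining size. A hedged end-to-end sketch, assuming a timm build that includes the samvit models (untrained weights, illustrative input size):

    import torch
    import timm

    # Default config is 1024x1024; a smaller, patch-divisible input should now be
    # accepted, with the absolute pos embed resampled on the fly in forward_features.
    model = timm.create_model('samvit_base_patch16', pretrained=False, num_classes=0)
    feats = model(torch.randn(1, 3, 512, 512))
    print(feats.shape)   # pooled features; likely (1, 256) with the default 256-channel neck
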
