Skip to content

Commit

Permalink
fix loading pretrained weight for samvit
Browse files Browse the repository at this point in the history
  • Loading branch information
seefun committed May 18, 2023
1 parent 15de561 commit c1c6eeb
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions timm/models/vision_transformer_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,11 @@ def checkpoint_filter_fn(
""" Remap SAM checkpoints -> timm """
out_dict = {}
for k, v in state_dict.items():
if 'image_encoder.' in k:
new_k = k.replace('image_encoder.', '')
new_k = new_k.replace('mlp.lin', 'mlp.fc')
out_dict[new_k] = v
return state_dict
if 'image_encoder.' in k:
new_k = k.replace('image_encoder.', '')
new_k = new_k.replace('mlp.lin', 'mlp.fc')
out_dict[new_k] = v
return out_dict


def _cfg(url='', **kwargs):
Expand All @@ -535,19 +535,19 @@ def _cfg(url='', **kwargs):
# Segment-Anything Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
'samvit_base_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
hf_hub_id='timm/',
# hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
'samvit_large_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
hf_hub_id='timm/',
# hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
'samvit_huge_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
hf_hub_id='timm/',
# hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
Expand Down

0 comments on commit c1c6eeb

Please sign in to comment.