diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index 9a9e74cc1f..5c395d2c5f 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -512,11 +512,11 @@ def checkpoint_filter_fn(
     """ Remap SAM checkpoints -> timm """
     out_dict = {}
     for k, v in state_dict.items():
-        if 'image_encoder.' in k:
-            new_k = k.replace('image_encoder.', '')
-            new_k = new_k.replace('mlp.lin', 'mlp.fc')
-            out_dict[new_k] = v
-    return state_dict
+        if 'image_encoder.' in k:
+            new_k = k.replace('image_encoder.', '')
+            new_k = new_k.replace('mlp.lin', 'mlp.fc')
+            out_dict[new_k] = v
+    return out_dict


 def _cfg(url='', **kwargs):
@@ -535,19 +535,19 @@ def _cfg(url='', **kwargs):
     # Segment-Anyhing Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
     'samvit_base_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
-        hf_hub_id='timm/',
+        # hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
     'samvit_large_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
-        hf_hub_id='timm/',
+        # hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
     'samvit_huge_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
-        hf_hub_id='timm/',
+        # hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),