From 15de561f2cc01997b248a44bba05668e41228930 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=96=B9=E6=9B=A6?=
Date: Wed, 17 May 2023 12:51:12 +0800
Subject: [PATCH] fix unit test for samvit

---
 tests/test_models.py                  |  2 +-
 timm/models/vision_transformer_sam.py | 20 ++++++++++----------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 247415b04d..d8ac8d6438 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -41,7 +41,7 @@
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
 
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index 98c0096db1..9a9e74cc1f 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -303,7 +303,7 @@ def add_decomposed_rel_pos(
 
 
 class VisionTransformerSAM(nn.Module):
-    """ Vision Transformer for vitsam or SAM
+    """ Vision Transformer for Segment-Anything Model(SAM)
 
     A PyTorch impl of : `Exploring Plain Vision Transformer Backbones for Object Detection`
         or `Segment Anything Model (SAM)` - https://arxiv.org/abs/2010.11929
@@ -533,19 +533,19 @@ def _cfg(url='', **kwargs):
 default_cfgs = generate_default_cfgs({
 
     # Segment-Anyhing Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
-    'vitsam_base_patch16.sa1b': _cfg(
+    'samvit_base_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
-    'vitsam_large_patch16.sa1b': _cfg(
+    'samvit_large_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
-    'vitsam_huge_patch16.sa1b': _cfg(
+    'samvit_huge_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
@@ -569,7 +569,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
 
 
 @register_model
-def vitsam_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-B/16 for Segment-Anything
     """
     model_args = dict(
@@ -577,12 +577,12 @@ def samvit_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vitsam_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-L/16 for Segment-Anything
     """
     model_args = dict(
@@ -590,12 +590,12 @@ def samvit_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vitsam_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-H/16 for Segment-Anything
     """
     model_args = dict(
@@ -603,7 +603,7 @@ def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 # TODO:
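
A minimal usage sketch (illustration only, not part of the patch), assuming a timm
build that includes this commit: the SAM backbones now register under the `samvit_`
prefix, and the sa1b default cfgs above imply a headless model (num_classes=0) that
expects 1024x1024 inputs. `timm.create_model` and `forward_features` are standard
timm API; the output-shape comment is an assumption about the SAM neck, not something
this diff states.

    import timm
    import torch

    # Build the renamed backbone; num_classes=0 matches the headless sa1b cfgs above.
    model = timm.create_model('samvit_base_patch16', pretrained=False, num_classes=0)
    model.eval()

    # The default cfgs declare input_size=(3, 1024, 1024).
    x = torch.randn(1, 3, 1024, 1024)
    with torch.no_grad():
        feats = model.forward_features(x)  # spatial feature map from the SAM neck
    # Expected roughly (1, C, 64, 64) for patch16 at 1024px; C assumed to be the
    # SAM neck width (256) -- verify against the model definition.
    print(feats.shape)

Because `samvit_*` is added to NON_STD_FILTERS, the test suite treats these models as
non-standard ViTs (exempt from the tests that assume a regular classifier interface),
which is all the one-line tests/test_models.py change needs to do.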