Skip to content

Commit

Permalink
fix unit test for samvit
Browse files Browse the repository at this point in the history
  • Loading branch information
seefun committed May 17, 2023
1 parent ea1f52d commit 15de561
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*'
'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

Expand Down
20 changes: 10 additions & 10 deletions timm/models/vision_transformer_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def add_decomposed_rel_pos(


class VisionTransformerSAM(nn.Module):
""" Vision Transformer for vitsam or SAM
""" Vision Transformer for Segment-Anything Model(SAM)
A PyTorch impl of : `Exploring Plain Vision Transformer Backbones for Object Detection` or `Segment Anything Model (SAM)`
- https://arxiv.org/abs/2010.11929
Expand Down Expand Up @@ -533,19 +533,19 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({

    # Segment-Anything Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
'vitsam_base_patch16.sa1b': _cfg(
'samvit_base_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
'vitsam_large_patch16.sa1b': _cfg(
'samvit_large_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
'vitsam_huge_patch16.sa1b': _cfg(
'samvit_huge_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
hf_hub_id='timm/',
license='apache-2.0',
Expand All @@ -569,41 +569,41 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):


@register_model
def vitsam_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
def samvit_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
""" ViT-B/16 for Segment-Anything
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, global_attn_indexes=[2, 5, 8, 11],
window_size=14, use_rel_pos=True, img_size=1024,
)
model = _create_vision_transformer(
'vitsam_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
'samvit_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vitsam_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
def samvit_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
""" ViT-L/16 for Segment-Anything
"""
model_args = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, global_attn_indexes=[5, 11, 17, 23],
window_size=14, use_rel_pos=True, img_size=1024,
)
model = _create_vision_transformer(
'vitsam_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
'samvit_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vitsam_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
""" ViT-H/16 for Segment-Anything
"""
model_args = dict(
patch_size=16, embed_dim=1280, depth=32, num_heads=16, global_attn_indexes=[7, 15, 23, 31],
window_size=14, use_rel_pos=True, img_size=1024,
)
model = _create_vision_transformer(
'vitsam_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model

# TODO:
Expand Down

0 comments on commit 15de561

Please sign in to comment.