-
Notifications
You must be signed in to change notification settings - Fork 2
/
piip_3branch_tsb_368-192-128_cls_token_deit1.py
72 lines (69 loc) · 2.25 KB
/
piip_3branch_tsb_368-192-128_cls_token_deit1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Model configuration: one backbone made of three branches. Each branch entry
# carries its own input size, transformer hyper-parameters (depth, embed_dim,
# num_heads, ...) and the path of the checkpoint it is initialised from.
# NOTE(review): `model` is deliberately the only module-level name defined
# here; config loaders commonly collect every top-level name — confirm before
# adding helpers.
model = {
    "backbone": {
        # Backbone-level settings that sit beside the three branches.
        # NOTE(review): presumably these parameterise the cross-branch
        # interaction (deformable attention + CFFN) — confirm in the
        # backbone implementation.
        "n_points": 4,
        "deform_num_heads": 16,
        "cffn_ratio": 0.25,
        "deform_ratio": 0.5,
        "with_cffn": True,
        "interact_attn_type": 'deform',
        "interaction_drop_path_rate": 0.4,
        "separate_head": True,
        # Branch 1: smallest input (128), widest embedding (768).
        # Checkpoint filename indicates DeiT-Base, patch 16, 224 px.
        "branch1": {
            "img_size": 128,
            "patch_size": 16,
            "pretrain_img_size": 224,
            "pretrain_patch_size": 16,
            "depth": 12,
            "embed_dim": 768,
            "num_heads": 12,
            "mlp_ratio": 4,
            "init_scale": 1.0,
            "qkv_bias": True,
            "drop_rate": 0.0,
            "drop_path_rate": 0.2,
            # Pairs [0, 0] .. [11, 11]; presumably one per block since
            # depth is 12 — confirm against the backbone.
            "interaction_indexes": [[i, i] for i in range(12)],
            "use_cls_token": True,
            "use_flash_attn": True,
            "with_cp": True,
            "pretrained": "pretrained/deit_base_patch16_224-b5f2ef4d.pth",
        },
        # Branch 2: middle input (192), embedding 384.
        # Checkpoint filename indicates DeiT-Small, patch 16, 224 px.
        "branch2": {
            "img_size": 192,
            "patch_size": 16,
            "pretrain_img_size": 224,
            "pretrain_patch_size": 16,
            "depth": 12,
            "embed_dim": 384,
            "num_heads": 6,
            "mlp_ratio": 4,
            "init_scale": 1.0,
            "qkv_bias": True,
            "drop_rate": 0.0,
            "drop_path_rate": 0.05,
            # Same [i, i] pairing as the other branches, i = 0 .. 11.
            "interaction_indexes": [[i, i] for i in range(12)],
            "use_cls_token": True,
            "use_flash_attn": True,
            "with_cp": True,
            "pretrained": "pretrained/deit_small_patch16_224-cd65a155.pth",
        },
        # Branch 3: largest input (368), narrowest embedding (192).
        # Checkpoint filename indicates DeiT-Tiny, patch 16, 224 px.
        "branch3": {
            "img_size": 368,
            "patch_size": 16,
            "pretrain_img_size": 224,
            "pretrain_patch_size": 16,
            "depth": 12,
            "embed_dim": 192,
            "num_heads": 3,
            "mlp_ratio": 4,
            "init_scale": 1.0,
            "qkv_bias": True,
            "drop_rate": 0.0,
            "drop_path_rate": 0.05,
            # Same [i, i] pairing as the other branches, i = 0 .. 11.
            "interaction_indexes": [[i, i] for i in range(12)],
            "use_cls_token": True,
            "use_flash_attn": True,
            "with_cp": True,
            "pretrained": "pretrained/deit_tiny_patch16_224-a1311bcf.pth",
        },
    },
}