# resa34_tusimple.py
# Configuration: RESA aggregator on a ResNet-34 backbone, TuSimple dataset.
# Top-level network entry. Only the `type` name is configured here;
# presumably it is looked up in the project's model registry -- confirm.
net = dict(
    type='Detector',
)
# Backbone: 'resnet34' wrapped by 'ResNetWrapper', with `pretrained` enabled.
backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    # Three flags, one per entry; presumably forwarded to the ResNet
    # constructor's argument of the same name -- confirm in ResNetWrapper.
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
)

# Channel count and stride of the feature map produced by the backbone.
# NOTE(review): these presumably have to agree with what ResNetWrapper
# actually emits given the settings above -- confirm.
featuremap_out_channel = 128
featuremap_out_stride = 8
# Settings for the 'RESA' feature aggregator.
aggregator = dict(
    type='RESA',
    # Direction codes, in this order; presumably down / up / right / left.
    direction=['d', 'u', 'r', 'l'],
    alpha=2.0,
    iter=4,
    conv_stride=9,
)
# Descending values 710, 700, ..., 360 (36 entries, step -10). The same
# range object is handed to the head below.
# NOTE(review): presumably the image rows at which lane points are sampled
# for evaluation -- confirm against the LaneSeg head.
sample_y = range(710, 350, -10)

# Segmentation head with a 'BUSD' decoder and a 0.6 threshold.
heads = dict(
    type='LaneSeg',
    decoder=dict(type='BUSD'),
    thr=0.6,
    sample_y=sample_y,
)
# SGD optimizer settings: base learning rate, L2 weight decay and momentum.
optimizer = dict(
    type='SGD',
    lr=0.025,
    weight_decay=1e-4,
    momentum=0.9,
)
# Training length and batch size.
epochs = 150
batch_size = 8

# Number of scheduler steps the decay below is spread over:
# (3616 // batch_size + 1) iterations per epoch, times `epochs`.
# NOTE(review): 3616 is presumably the training-set size -- confirm it
# matches the dataset actually loaded; if the real number of iterations per
# epoch is larger, training runs past `total_iter`.
total_iter = (3616 // batch_size + 1) * epochs

import math

# Learning-rate multiplier (1 - step / total_iter) ** 0.9, i.e. it starts at
# 1.0 and reaches 0.0 at `total_iter`.
# The base is clamped at 0 because math.pow() raises
# "ValueError: math domain error" for a negative base with a fractional
# exponent, which would crash training as soon as the step counter exceeds
# `total_iter`. Past that point the multiplier simply stays at 0.0.
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(max(0.0, 1 - _iter / total_iter), 0.9),
)
# NOTE(review): presumably the loss weight applied to the background class
# -- confirm where it is consumed.
bg_weight = 0.4

# Per-channel normalization constants passed to the 'Normalize' step below.
# NOTE(review): std is all ones; the channel order the mean assumes is not
# visible here -- confirm against the Normalize implementation.
img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.],
)

# Network input size used by the 'Resize' steps below.
img_height = 368
img_width = 640
# NOTE(review): presumably the number of top rows cropped from the original
# image -- confirm in the dataset code.
cut_height = 160
# Size of the original images.
ori_img_h = 720
ori_img_w = 1280

# Training pipeline: random rotation and horizontal flip, then resize,
# normalize and convert to tensors. Steps are listed in execution order.
train_process = [
    dict(type='RandomRotation'),
    dict(type='RandomHorizontalFlip'),
    dict(type='Resize', size=(img_width, img_height)),
    dict(type='Normalize', img_norm=img_norm),
    dict(type='ToTensor'),
]

# Validation/test pipeline: no random augmentation, and 'ToTensor' is
# restricted to the 'img' key.
val_process = [
    dict(type='Resize', size=(img_width, img_height)),
    dict(type='Normalize', img_norm=img_norm),
    dict(type='ToTensor', keys=['img']),
]
# Root directory shared by all three dataset splits below.
dataset_path = './data/tusimple'

# Dataset definitions. Training uses the 'trainval' split with the training
# pipeline; both `val` and `test` use the 'test' split with the validation
# pipeline.
dataset = dict(
    train=dict(
        type='TuSimple',
        data_root=dataset_path,
        split='trainval',
        processes=train_process,
    ),
    val=dict(
        type='TuSimple',
        data_root=dataset_path,
        split='test',
        processes=val_process,
    ),
    test=dict(
        type='TuSimple',
        data_root=dataset_path,
        split='test',
        processes=val_process,
    ),
)
# NOTE(review): `batch_size` is already assigned the same value earlier in
# this file, where `total_iter` is derived from it. Changing only this
# assignment would leave `total_iter` computed from the old value -- keep
# the two in sync.
batch_size = 8
workers = 12

# NOTE(review): presumably 6 lane classes plus 1 background class -- confirm.
num_classes = 6 + 1
ignore_label = 255

# Logging / evaluation / checkpoint intervals.
log_interval = 100
eval_ep = 1
# Checkpoint interval is set to the total number of epochs.
save_ep = epochs

test_json_file = 'data/tusimple/test_label.json'

# False here; presumably this makes the runner step the LR scheduler every
# iteration rather than every epoch -- confirm in the training loop.
lr_update_by_epoch = False