-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconfig.toml
47 lines (43 loc) · 1.08 KB
/
config.toml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Settings read under the [PSO] section.
# NOTE(review): "PSO" presumably means particle swarm optimization, with
# population_number particles run for iter_number iterations — confirm
# against the consuming code. Units of random_radius and padding_step are
# not visible here.
[PSO]
random_radius = 5
population_number = 5
iter_number = 30
padding_step = 2
# Dataset extraction / preparation settings.
# NOTE(review): all paths are relative; presumably resolved against the
# working directory of the consuming script — confirm.
[DATAEXTRACT]
data_origin_path = 'data/origin'
data_square_path = 'data/square'
data_circle_path = 'data/circle'
# These currently equal [TRAIN] train_path / valid_path. TOML has no
# interpolation, so if they must stay in sync, update both places by hand.
data_split_train_path = 'data/train'
data_split_valid_path = 'data/valid'
# NOTE(review): despite "origin" in its name, this points at the same
# directory as data_square_path — confirm this is intentional.
data_split_origin_path = 'data/square'
# train_ratio can be much smaller when training on a statically built dataset.
train_ratio = 0.8
# NOTE(review): presumably controls whether data_origin_path is cleared — confirm.
clear_origin = false
# Previously [224, 224].
img_size = [128, 128]
[MODEL]
# Supported values: resnet18, resnet34, resnet50, resnet101, resnet152.
model_type = 'resnet18'
# NOTE(review): presumably the dimension of the output feature vector — confirm.
feature_dim = 128
# NOTE(review): presumably selects pretrained weights for model_type — confirm.
pretrained = true
[TRAIN]
# NOTE(review): presumably the output directory for saved weights — confirm.
save_weights = 'runs'
train_path = 'data/train'
valid_path = 'data/valid'
device = 'cuda:0'
batch_size = 16
# NOTE(review): the meaning of back_true is not visible here — document once confirmed.
back_true = 0.5
# NOTE(review): presumably the number of data-loader worker processes — confirm.
works = 1
# Supported loss functions: CosineSimilarityLoss, ClassFiyOneLoss,
# ClassFiyTwoLoss, CosineMarginOneLoss, CosineMarginTwoLoss, and
# PalmCombinedLoss (the project's default loss function).
# Keep the spellings exactly as listed (including "ClassFiy"); presumably
# they are looked up by name — confirm.
loss = 'PalmCombinedLoss'
similarity_threshold = 0.75
# NOTE(review): presumably only used by the margin-based losses — confirm.
loss_margin = 0.0
shuffle = true
lr = 0.001
epochs = 200
# NOTE(review): units of the intervals below (epochs vs. iterations) are not
# visible here — confirm against the training loop.
val_interval = 10
log_interval = 200
save_epoch = 10
[DETECT]
# Device string for the detection stage; currently the same value as [TRAIN] device.
# NOTE(review): 'cuda:0' presumably requires a GPU — confirm what the consumer
# accepts on GPU-less hosts (e.g. 'cpu').
device = 'cuda:0'