-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.sh
43 lines (42 loc) · 1.14 KB
/
train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# -----------------------------------------------
# Launch train.py with the configuration selected below.
# To switch a setting, comment/uncomment the desired line in its section.
# Any extra command-line arguments given to this script are appended to the
# train.py invocation, e.g. `sh train.sh --gpu_id 0`.
# NOTE(review): overriding an option already set here relies on train.py
# accepting a repeated flag with the last occurrence winning (argparse's
# default behaviour) — confirm against train.py.
# -----------------------------------------------
set -eu
# -------- architecture -------------------------
# arch=OMPa
# arch=OMPb
arch=OMPc
# -------- model --------------------------------
# Options listed below: vgg11/vgg13/vgg16/vgg19/resnet20/resnet32/modela
# model=vgg11
# model=vgg13
# model=vgg16
# model=vgg19
# model=resnet20
# model=resnet32
model=modela
# -------- hyper-parameters ---------------------
# Forwarded to train.py as --lamb / --num_paths; their meaning is defined
# in train.py.
lamb=0.1
num_paths=10
# -------- CIFAR10 ------------------------------
# dataset=CIFAR10
# data_dir='/media/Disk1/KunFang/data/CIFAR10/'
# -------- CIFAR100 -----------------------------
# dataset=CIFAR100
# data_dir='/media/Disk1/KunFang/data/CIFAR100/'
# -------- STL10 --------------------------------
dataset=STL10
data_dir='/media/Disk1/KunFang/data/STL10/'
# -------- model directory ----------------------
model_dir='./save/'
# -----------------------------------------------
gpu_id=1
# -----------------------------------------------
# NOTE(review): this is passed as the literal string "True"/"False". If
# train.py parses it with argparse `type=bool`, any non-empty string
# (including "False") is truthy — confirm how train.py converts it.
# adv_train=False
adv_train=True
# -----------------------------------------------
# All expansions are quoted so paths/values are passed as single arguments
# (no word-splitting or globbing); "$@" forwards extra user arguments.
python train.py \
  --arch "${arch}" \
  --model "${model}" \
  --lamb "${lamb}" \
  --num_paths "${num_paths}" \
  --dataset "${dataset}" \
  --data_dir "${data_dir}" \
  --model_dir "${model_dir}" \
  --gpu_id "${gpu_id}" \
  --adv_train "${adv_train}" \
  "$@"