forked from szq0214/MEAL-V2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
opts.py
58 lines (47 loc) · 2.89 KB
/
opts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from torch.utils import data as data_utils
from models import model_factory
def add_general_flags(parser):
    """Register the output-directory and device-selection options on ``parser``.

    Adds ``--save`` plus two flags that both write to ``args.gpus``:
    ``--gpus``/``--gpu`` (a list of GPU ids) and ``--cpu`` (stores an
    empty list).

    Args:
        parser: An ``argparse.ArgumentParser`` (or argument group) that is
            extended in place. Nothing is returned.
    """
    flag_specs = [
        # Directory that receives logs and checkpoints.
        (('--save',), dict(
            default='checkpoints',
            help="Path to the directory to save logs and checkpoints.",
        )),
        # One or more integer GPU ids; '--gpu' is an alias of '--gpus'.
        (('--gpus', '--gpu'), dict(
            nargs='+',
            default=[0],
            type=int,
            help="The GPU(s) on which the model should run. "
                 "The first GPU will be the main one.",
        )),
        # Shares dest 'gpus' with the flag above; stores the constant [].
        (('--cpu',), dict(
            action='store_const',
            const=[],
            dest='gpus',
            help="If set, no gpus will be used.",
        )),
    ]
    for option_strings, options in flag_specs:
        parser.add_argument(*option_strings, **options)
def add_dataset_flags(parser):
    """Register ImageNet data-loading options on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or argument group) that is
            extended in place. Nothing is returned.

    Note:
        ``--imagenet`` is declared ``required=True``, so parsing fails when
        it is omitted.
    """
    parser.add_argument('--imagenet', required=True,
                        help="Path to ImageNet's root directory holding "
                             "'train/' and 'val/' directories.")
    # Per the help text, this is the total batch size across all GPUs.
    parser.add_argument('--batch-size', default=256, type=int,
                        help="Batch size to use distributed over all GPUs.")
    # '-j' is a short alias for '--num-workers'.
    parser.add_argument('--num-workers', '-j', default=40, type=int,
                        help="Number of data loading processes to use for "
                             "loading data and transforming.")
    parser.add_argument('--image-size', default=224, type=int,
                        help="Input image size used for training.")
def add_model_flags(parser):
    """Register model-architecture and epoch-range options on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or argument group) that is
            extended in place. Nothing is returned.

    Note:
        ``--model`` is declared ``required=True``, so parsing fails when it
        is omitted.
    """
    parser.add_argument('--model', required=True,
                        help="The model architecture name.")
    # The trailing space before the literal break matters: without it the
    # implicitly concatenated help would read "modelstate".
    parser.add_argument('--student-state-file', default=None,
                        help="Path to student model state file to initialize "
                             "the student model.")
    parser.add_argument('--start-epoch', default=1, type=int,
                        help="manual epoch number useful on restarts.")
    parser.add_argument('--epochs', default=200, type=int,
                        help='number of total epochs to run')
def add_teacher_flags(parser):
    """Register the teacher-model options on ``parser``.

    Per the help text, the teacher model(s) generate soft labels per crop.

    Args:
        parser: An ``argparse.ArgumentParser`` (or argument group) that is
            extended in place. Nothing is returned.
    """
    teacher_model_help = "The model that will generate soft labels per crop."
    teacher_state_help = "Path to teacher model state file."

    # NOTE(review): the default joins two architecture names with a comma;
    # presumably the caller splits it into several teachers -- confirm.
    parser.add_argument('--teacher-model',
                        help=teacher_model_help,
                        default="vgg19_bn,resnet50")
    parser.add_argument('--teacher-state-file',
                        help=teacher_state_help,
                        default=None)
def add_training_flags(parser):
    """Register learning-rate and optimization options on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or argument group) that is
            extended in place. Nothing is returned.
    """
    # A flat list of floats. Per the help text it overrides the model's
    # default learning-rate regime and is laid out as [start, end, lr, ...]
    # (presumably repeated triples -- confirm with the consumer). None keeps
    # the default regime.
    parser.add_argument('--lr-regime', default=None, nargs='+', type=float,
                        help="If set, it will override the default learning "
                        "rate regime of the model. Learning rate passed must "
                        "be as list of [start, end, lr, ...].")
    # NOTE: spelled with an underscore, unlike the other hyphenated flags;
    # renaming it would break existing command lines.
    parser.add_argument('--d_lr', default=1e-4, type=float,
                        help="The learning rate for discriminator training")
    parser.add_argument('--momentum', default=0.9, type=float,
                        help="The momentum of the optimization.")
    # argparse does not pass non-string defaults through ``type``, so a float
    # literal is used to keep ``args.weight_decay`` a float whether or not the
    # flag is given.
    parser.add_argument('--weight-decay', default=0.0, type=float,
                        help="The weight decay of the optimization.")