-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
executable file
·65 lines (57 loc) · 2.29 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import numpy as np
import random
from model.image_cnn import CNN
from trainer import *
# from utils import config
import utils
from sklearn import metrics
import torch.nn as nn
import os
from matplotlib import pyplot as plt
from IPython.display import clear_output
import importlib
from sklearn.model_selection import ParameterGrid
import collections
import yaml
from experiment import Experiment
# from apex import amp, fp16_utils
import re
import argparse
# ---------------------------------------------------------------------------
# Command-line interface.
#
# Arguments are parsed at import time and exposed as module-level globals that
# main() reads (model_name, model_type, budget, repeats, save_every,
# eval_train, optimizer, save_best_num).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='')
parser.add_argument('--model_name', type=str, required=True,
                    help="key of the parameter grid in hyperparameters.yaml")
parser.add_argument('--model_type', type=str, required=True,
                    help="prefix of the config name '<model_type>.<model_name>'")
# The previous `default="0,1"` was dead code: argparse ignores `default` for a
# required option, so it has been removed. The flag stays mandatory.
parser.add_argument('--gpu', type=str, required=True,
                    help="value assigned to CUDA_VISIBLE_DEVICES, e.g. '0,1'")
# The following values are passed straight through to Experiment; their exact
# semantics live there (not visible in this file).
parser.add_argument('--budget', type=int, required=False, default=50)
parser.add_argument('--repeats', type=int, required=False, default=3)
parser.add_argument('--save_every', type=int, required=False, default=None)
parser.add_argument('--save_best_num', type=int, required=False, default=1)
# NOTE: argparse applies `type` only to string values, so when the flag is
# omitted this is the bool True, and when given (e.g. `--eval_train 0`) it is
# an int. Both are used as-is by Experiment.
parser.add_argument('--eval_train', type=int, required=False, default=True)
parser.add_argument('--optimizer', type=str, required=False, default="sgd")
args = parser.parse_args()

model_name = args.model_name
model_type = args.model_type
budget = args.budget
repeats = args.repeats
save_every = args.save_every
eval_train = args.eval_train
optimizer = args.optimizer.lower()  # normalise so e.g. "SGD" and "sgd" match
save_best_num = args.save_best_num

# Number GPUs by PCI bus ID and expose only the requested ones. These must be
# set before CUDA is first initialised (torch initialises it lazily on the
# first CUDA call, which happens later in main()).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
def main():
    """Build and run an Experiment over the hyperparameter grid for one model.

    Steps:
      1. Load ``hyperparameters.yaml`` and select the grid stored under
         ``model_name``.
      2. Resolve the checkpoint directory via
         ``config('<model_type>.<model_name>.checkpoint')`` and create it,
         together with its ``log/`` subdirectory, if missing.
      3. Run an ``Experiment`` with the command-line settings and write the
         returned search results to ``<checkpoint dir>/log/df_search.csv``.

    Reads the module-level globals populated from the parsed arguments.

    Raises:
        KeyError: if ``model_name`` has no entry in ``hyperparameters.yaml``.
    """
    config_str = model_type + "." + model_name

    with open('hyperparameters.yaml', encoding='utf-8') as f:
        hyperparameters = yaml.safe_load(f)
    try:
        param_grid = hyperparameters[model_name]
    except KeyError:
        raise KeyError(
            "model '{}' not found in hyperparameters.yaml".format(model_name)
        ) from None

    # 'cuda:0' is the first device left visible by CUDA_VISIBLE_DEVICES.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): `config` is not defined in this file; it presumably comes
    # from `from trainer import *` (the direct `from utils import config` is
    # commented out). Confirm where it is defined.
    savename = config(config_str + ".checkpoint")
    log_dir = os.path.join(savename, 'log')
    # One call creates both savename and savename/log; exist_ok avoids the
    # race between an exists() check and the subsequent makedirs().
    os.makedirs(log_dir, exist_ok=True)

    exp = Experiment(optimizer, device, config_str, model_type, model_name,
                     param_grid, save_best_num, savename, budget, repeats,
                     eval_train=eval_train, save_every=save_every)
    df_search = exp.run()
    # presumably a pandas DataFrame (it exposes to_csv(..., index=False))
    df_search.to_csv(os.path.join(log_dir, 'df_search.csv'), index=False)
if __name__ == "__main__":
    # Launch the search only when this file is executed as a script.
    main()