-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
executable file
·72 lines (61 loc) · 2.21 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python
import numpy as np
import pandas as pd
from argparse import ArgumentParser
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.cross_validation import cross_val_score, KFold
import cPickle as pickle
def main():
    """Parse the command line and run the requested modeling step.

    Modes:
        train     -- fit a model on the CSV and pickle it to --modelfile.
        evaluate  -- print cross-validated and in-sample R^2 scores.
        test      -- load the pickled model from --modelfile and print its score.

    Raises:
        SystemExit: if the mode is not one of train, evaluate, or test.
    """
    args = parse_args()
    if args.mode not in ('train', 'evaluate', 'test'):
        # Reject unknown modes before reading the CSV; otherwise no branch
        # below would run and the script would exit without doing anything.
        raise SystemExit(
            "unknown mode %r: expected train, evaluate, or test" % args.mode)
    modeler = Modeler(args)
    if args.mode == 'train':
        modeler.train()
    elif args.mode == 'evaluate':
        modeler.evaluate()
    else:
        modeler.test()
def parse_args(argv=None):
    """Parse the script's command-line arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse fall back to ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``infile``, ``target``, ``mode`` and
        ``modelfile`` attributes (``modelfile`` is None when omitted).

    Exits with status 2 (via ``parser.error``) when ``mode`` is not one of
    train/evaluate/test, or when ``--modelfile`` is missing for a mode that
    writes (train) or reads (test) the pickled model.
    """
    description = 'train microwave swe -> reconstructed swe model'
    parser = ArgumentParser(description=description)
    parser.add_argument('infile', help='location of csv file')
    parser.add_argument('target', help='name of the target variable')
    parser.add_argument('mode', choices=['train', 'evaluate', 'test'],
                        help='train, evaluate, or test')
    parser.add_argument('--modelfile',
                        help='location to write model description')
    args = parser.parse_args(argv)
    # train pickles the model to this path and test loads it back, so both
    # need it; catching this here avoids failing only after a long fit.
    if args.mode in ('train', 'test') and args.modelfile is None:
        parser.error('--modelfile is required for mode %r' % args.mode)
    return args
class Modeler(object):
    """Fit, cross-validate, persist, and score a regression model on CSV data.

    Every CSV column except the target is used as a feature. Rows with a
    missing value in any column are dropped before anything else happens.
    """

    def __init__(self, args):
        """Load the CSV and split it into a feature frame and a target series.

        Args:
            args: parsed command-line namespace; reads ``infile`` (CSV path),
                ``target`` (target column name) and ``modelfile`` (path of
                the pickled model, may be None).

        Raises:
            KeyError: if the target column is not present in the CSV.
        """
        self.modelfile = args.modelfile
        target = args.target
        # dropna() with default arguments drops a row if *any* column is NaN.
        data = pd.read_csv(args.infile).dropna()
        self.X = data.drop(target, axis=1)
        self.y = data[target]

    def evaluate(self):
        """Print in-sample and cross-validated R^2 for a freshly built model."""
        model = self.get_model()
        # Old sklearn.cross_validation API: KFold(n_samples, ...) with its
        # default number of folds. No random_state is passed, so the shuffle
        # (and therefore the CV scores) varies from run to run.
        folds = KFold(len(self.y), shuffle=True)
        scores = cross_val_score(model, self.X, self.y,
                                 cv=folds, scoring='r2')
        # Refit on every row; this score is measured on the training data
        # itself, so expect it to be higher than the cross-validated scores.
        model.fit(self.X, self.y)
        print("score: %s" % model.score(self.X, self.y))
        print("R^2 CV scores:")
        print(scores)

    def train(self):
        """Fit a fresh model on all rows and pickle it to ``modelfile``.

        Raises:
            ValueError: if no ``modelfile`` was supplied (checked before fitting).
        """
        self._require_modelfile('train')
        model = self.get_model()
        model.fit(self.X, self.y)
        # Protocol -1 selects the highest (binary) pickle protocol, so the
        # file must be opened in binary mode; `with` guarantees it is closed.
        with open(self.modelfile, 'wb') as handle:
            pickle.dump(model, handle, -1)

    def test(self):
        """Load the pickled model from ``modelfile`` and print its score.

        Raises:
            ValueError: if no ``modelfile`` was supplied.
        """
        self._require_modelfile('test')
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(self.modelfile, 'rb') as handle:
            model = pickle.load(handle)
        print("score: %s" % model.score(self.X, self.y))

    def _require_modelfile(self, mode):
        """Raise ValueError when ``modelfile`` is unset for a mode needing it."""
        if self.modelfile is None:
            raise ValueError("--modelfile is required for mode %r" % mode)

    @staticmethod
    def get_model():
        """Return a new, unfitted RandomForestRegressor shared by all modes.

        ``min_samples_split=100`` requires at least 100 samples to split an
        internal node; ``n_jobs=-1`` uses all available processors.
        """
        # NOTE: an alternative linear baseline was
        # LinearRegression(normalize=True); swap it in here to compare.
        return RandomForestRegressor(min_samples_split=100, n_jobs=-1)
if __name__ == '__main__':
    # Run the command-line tool only when executed as a script; importing
    # this module (e.g. to reuse Modeler) must not trigger argument parsing.
    main()