forked from FunctionLab/ExPecto
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
97 lines (81 loc) · 4.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Training a ExPecto sequence-based expression model.
This script fits a model to one expression profile: the expression values
in the targetIndex-th column of expFile (a 0-based column position). The
expression values can be RPKM from RNA-seq; they are log-transformed after
adding a pseudocount. The rows of expFile must match the genes or TSSes
specified in ./resources/geneanno.csv.

Example:
$ python ./train.py --expFile ./resources/geneanno.exp.csv --targetIndex 1 --output model.adipose
"""
import argparse
import xgboost as xgb
import pandas as pd
import numpy as np
from scipy.stats import spearmanr
import h5py
# Command-line interface.
#
# --targetIndex, --output and --expFile have no usable default, so they are
# marked required: a missing one is reported by argparse immediately instead
# of surfacing later as an unrelated error (for --output, only after the
# whole training run has finished).
parser = argparse.ArgumentParser(
    description='Train an ExPecto sequence-based expression model.')
parser.add_argument('--targetIndex', action="store", dest="targetIndex",
                    type=int, required=True,
                    help='0-based position of the expression column in expFile')
parser.add_argument('--output', action="store", dest="output", required=True,
                    help='prefix of the saved model file names')
parser.add_argument('--expFile', action="store", dest="expFile", required=True,
                    help='CSV file of expression values, one row per gene/TSS')
parser.add_argument('--inputFile', action="store", dest="inputFile",
                    default='./resources/Xreducedall.2002.npy',
                    help='NumPy .npy file holding the feature matrix')
parser.add_argument('--annoFile', action="store", dest="annoFile",
                    default='./resources/geneanno.csv',
                    help='gene annotation CSV file')
parser.add_argument('--evalFile', action="store", dest="evalFile", default='',
                    help='specify to save holdout set predictions')
parser.add_argument('--filterStr', action="store", dest="filterStr",
                    type=str, default="all",
                    help='genes to use: all (everything except rRNA), '
                         'pc (protein_coding) or lincRNA')
parser.add_argument('--pseudocount', action="store", dest="pseudocount",
                    type=float, default=0.0001,
                    help='value added to expression before the log transform')
parser.add_argument('--num_round', action="store", dest="num_round",
                    type=int, default=100,
                    help='number of boosting rounds')
parser.add_argument('--l2', action="store", dest="l2", type=float, default=100,
                    help='L2 regularization strength')
parser.add_argument('--l1', action="store", dest="l1", type=float, default=0,
                    help='L1 regularization strength')
parser.add_argument('--eta', action="store", dest="eta",
                    type=float, default=0.01,
                    help='learning rate (XGBoost eta)')
parser.add_argument('--base_score', action="store", dest="base_score",
                    type=float, default=2,
                    help='initial prediction score (XGBoost base_score)')
parser.add_argument('--threads', action="store", dest="threads",
                    type=int, default=16,
                    help='number of threads (XGBoost nthread)')
args = parser.parse_args()
# ---------------------------------------------------------------------------
# Train a linear XGBoost model on one expression profile.
#
# Reads the feature matrix, the gene annotation and the expression table,
# selects genes by type and by finite log-expression, trains on every
# chromosome except chrX, chrY and chr8, evaluates on chr8, and writes the
# model (.save) and its text dump (.dump).
# ---------------------------------------------------------------------------

# read resources
Xreducedall = np.load(args.inputFile)
# Read the annotation from --annoFile. The path used to be hard-coded here,
# which silently ignored the option; its default is that same path.
geneanno = pd.read_csv(args.annoFile)

# The gene type is taken from the last column of the annotation file.
gene_type = geneanno.iloc[:, -1]
if args.filterStr == 'pc':
    filt = np.asarray(gene_type == 'protein_coding')
elif args.filterStr == 'lincRNA':
    filt = np.asarray(gene_type == 'lincRNA')
elif args.filterStr == 'all':
    filt = np.asarray(gene_type != 'rRNA')
else:
    raise ValueError('filterStr has to be one of all, pc, and lincRNA')

geneexp = pd.read_csv(args.expFile)

# All three inputs are indexed with the same row masks below, so they must
# have the same number of rows. Check it up front for a clear error message.
if not (Xreducedall.shape[0] == len(geneanno) == len(geneexp)):
    raise ValueError(
        'row counts differ: inputFile has %d, annoFile has %d, expFile has %d'
        % (Xreducedall.shape[0], len(geneanno), len(geneexp)))

# Log-transformed target, computed once and reused for the filter, the
# training/test labels and the optional evaluation output.
logexp = np.asarray(
    np.log(geneexp.iloc[:, args.targetIndex] + args.pseudocount))
# Drop rows whose log-expression is not finite (NaN or +/-inf).
filt = filt & np.isfinite(logexp)

# training
seqnames = geneanno['seqnames']
# chr8 is the holdout set; chrX and chrY are in neither set.
trainind = np.asarray(
    (seqnames != 'chrX') & (seqnames != 'chrY') & (seqnames != 'chr8'))
testind = np.asarray(seqnames == 'chr8')
train_mask = trainind & filt
test_mask = testind & filt

dtrain = xgb.DMatrix(Xreducedall[train_mask, :], label=logexp[train_mask])
dtest = xgb.DMatrix(Xreducedall[test_mask, :], label=logexp[test_mask])

# 'alpha' now takes --l1. It used to be hard-coded to 0, which silently
# ignored the option; its default is 0.
# NOTE(review): 'early_stopping_rounds' is an argument of xgb.train rather
# than a booster parameter, so it presumably has no effect here. It is kept
# as-is to preserve behaviour; confirm before moving it.
# NOTE(review): 'reg:linear' is the older name of 'reg:squarederror'; kept
# so the saved model is unchanged. Confirm against the XGBoost version used.
param = {'booster': 'gblinear', 'base_score': args.base_score,
         'alpha': args.l1, 'lambda': args.l2, 'eta': args.eta,
         'objective': 'reg:linear', 'nthread': args.threads,
         "early_stopping_rounds": 10}
evallist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = args.num_round
bst = xgb.train(param, dtrain, num_round, evallist)
ypred = bst.predict(dtest)

# Optionally save holdout predictions next to their targets.
if args.evalFile != '':
    evaldf = pd.DataFrame({'pred': ypred, 'target': logexp[test_mask]})
    evaldf.to_csv(args.evalFile)

# Both output files share one name stem, differing only in the extension.
model_stem = (args.output + args.filterStr
              + '.pseudocount' + str(args.pseudocount)
              + '.lambda' + str(args.l2)
              + '.round' + str(args.num_round)
              + '.basescore' + str(args.base_score)
              + '.' + geneexp.columns[args.targetIndex])
bst.save_model(model_stem + '.save')
bst.dump_model(model_stem + '.dump')