runme.py (forked from NVIDIA/gbm-bench)
#!/usr/bin/env python
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import sys
import argparse
import json
import ast
import psutil
import algorithms
from metrics import get_metrics
import gc
from datasets import (prepare_airline, prepare_airline_regression, prepare_bosch,
                      prepare_fraud, prepare_higgs, prepare_year, prepare_epsilon,
                      prepare_covtype)


def print_sys_info():
    """Print the versions of the available GBM libraries and basic system info."""
    try:
        import xgboost  # pylint: disable=import-outside-toplevel
        print("Xgboost : %s" % xgboost.__version__)
    except ImportError:
        pass
    try:
        import lightgbm  # pylint: disable=import-outside-toplevel
        print("LightGBM: %s" % lightgbm.__version__)
    except (ImportError, OSError):
        pass
    try:
        import catboost  # pylint: disable=import-outside-toplevel
        print("Catboost: %s" % catboost.__version__)
    except ImportError:
        pass
    print("System : %s" % sys.version)
    print("#CPUs : %d" % psutil.cpu_count(logical=False))


def prepare_dataset(dataset_folder, dataset_parameters):
    """Create the dataset folder if needed and dispatch to the matching prepare_* loader."""
    if not os.path.exists(dataset_folder):
        os.makedirs(dataset_folder)
    prepare_function = globals()["prepare_" + dataset_parameters['dataset_name']]
    return prepare_function(dataset_folder, dataset_parameters)
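
# The dataset_name values accepted here correspond to the prepare_* functions
# imported from datasets above: airline, airline_regression, bosch, fraud, higgs,
# year, epsilon and covtype.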


def parse_args():
    parser = argparse.ArgumentParser(
        description="Benchmark xgboost/lightgbm/catboost on real datasets")
    parser.add_argument("-root", default="/opt/gbm-datasets",
                        type=str, help="The root datasets folder")
    parser.add_argument("-input", required=True,
                        help="JSON file that contains experiment parameters")
    parser.add_argument("-output", default=sys.path[0] + "/results.json", type=str,
                        help="Output JSON file with runtime/accuracy stats")
    parser.add_argument("-verbose", action="store_true", help="Produce verbose output")
    args = parser.parse_args()
    return args
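
# Illustrative sketch of the -input experiments file, inferred from how main()
# consumes it below; the algorithm name and parameter values are assumptions for
# illustration only, not a documented configuration:
#
# {
#   "experiments": [
#     {
#       "algo": "<name registered with algorithms.Algorithm.create>",
#       "dataset_parameters": {"dataset_name": "higgs", "nrows": 1000000},
#       "algorithm_parameters": {"max_depth": 8, "ntrees": 500}
#     }
#   ]
# }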


# benchmarks a single dataset
def benchmark(algo, dataset_dir, dataset_parameters, algorithm_parameters):
    data = prepare_dataset(dataset_dir, dataset_parameters)
    results = {}
    runner = algorithms.Algorithm.create(algo)
    with runner:
        train_time = runner.fit(data, algorithm_parameters)
        pred = runner.test(data)
        result = {
            "train_time": train_time,
            "accuracy": get_metrics(data, pred)
        }
    del data
    gc.collect()
    return result
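
# Shape of the value benchmark() returns (a sketch; the individual metric names come
# from get_metrics in metrics.py and are not assumed here):
#   {"train_time": <value returned by runner.fit>, "accuracy": {<metric>: <value>, ...}}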


def main():
    args = parse_args()
    print_sys_info()

    with open(args.input) as fp:
        experiments = json.load(fp)

    results = []
    for exp in experiments['experiments']:
        output = exp.copy()
        dataset_parameters = exp['dataset_parameters']
        if 'nrows' not in dataset_parameters:
            dataset_parameters['nrows'] = None
        algorithm_parameters = exp['algorithm_parameters']
        dataset_dir = os.path.join(
            args.root, dataset_parameters['dataset_name'])
        res = benchmark(exp['algo'], dataset_dir, dataset_parameters,
                        algorithm_parameters)
        output.update({'result': res})
        results.append(output)

    results_str = json.dumps({'experiments': results}, indent=2, sort_keys=True)
    with open(args.output, "w") as fp:
        fp.write(results_str + "\n")
    print("Results written to file '%s'" % args.output)


if __name__ == "__main__":
    main()
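
# Example invocation (illustrative; the experiments file name is an assumption):
#   python runme.py -root /opt/gbm-datasets -input experiments.json -output results.json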