-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdecision_forest.cc
151 lines (126 loc) · 4.08 KB
/
decision_forest.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include "decision_forest.h"
#include <cassert>
#include <fstream>
#include "bounding_box.h"
#include "decision_tree.h"
#include "nlohmann/json.hpp"
using nlohmann::json;
namespace cz {
DecisionForest::DecisionForest(int num_class, int max_feature_id)
: num_class_(num_class), max_feature_id_(max_feature_id) {}
// Out-of-line destructor with no custom logic; members are destroyed by their
// own destructors.
DecisionForest::~DecisionForest() = default;
// Builds a forest from the JSON document stored at `path`.
//
// The parsed document is iterated element by element; each element is
// asserted to be a JSON object and is turned into one DecisionTree. Class ids
// are assigned by position: when `num_class` is 2 every tree gets class id 1,
// otherwise the i-th tree gets class id (i % num_class). Once every tree has
// been added, Setup() is run, so the returned forest already has its bounding
// boxes and feature splits computed.
//
// Parameters:
//   path           - file to read the JSON model from.
//   num_class      - number of classes; forwarded to the constructor.
//   max_feature_id - largest feature id; forwarded to the constructor.
//
// Returns: the fully set-up forest.
//
// NOTE(review): the stream is not checked after opening; a missing or
// malformed file presumably surfaces as an exception from the JSON parser --
// confirm against the nlohmann::json documentation.
std::unique_ptr<DecisionForest> DecisionForest::CreateFromJson(
    const std::string& path,
    int num_class,
    int max_feature_id) {
  std::ifstream fin(path);
  json forest_array;
  fin >> forest_array;

  auto forest = std::make_unique<DecisionForest>(num_class, max_feature_id);
  int tree_count = 0;
  for (const auto& tree_obj : forest_array) {
    assert(tree_obj.is_object());
    // Two-class models put every tree on class 1; otherwise trees are dealt
    // out to classes in round-robin order.
    const int class_id = (num_class == 2) ? 1 : (tree_count % num_class);
    forest->AddDecisionTree(
        DecisionTree::CreateFromJson(tree_obj, class_id, true));
    ++tree_count;
  }
  forest->Setup();
  // Return the local by name: wrapping it in std::move() would only inhibit
  // copy elision.
  return forest;
}
// Returns the label for `x`: the result of MaxIndex() applied to the per-class
// scores produced by ComputeScores().
int DecisionForest::PredictLabel(const Point& x) const {
  const std::vector<double> class_scores = ComputeScores(x);
  return MaxIndex(class_scores);
}
int DecisionForest::PredictLabelBetween(const Point& x,
int class1,
int class2) const {
double score = 0;
for (const auto& tree : trees_) {
if (tree->ClassId() == class2) {
score += tree->PredictLabel(x);
} else if (tree->ClassId() == class1) {
score -= tree->PredictLabel(x);
}
}
if (score > 0)
return class2;
return class1;
}
// Finalizes the forest once all trees have been added: first computes the
// bounding boxes of every tree, then builds the per-feature split table.
// CreateFromJson() calls this before returning the forest. Both steps assert
// that they have not run before, so this must not be called twice.
void DecisionForest::Setup() {
ComputeBoundingBox();
ComputeFeatureSplits();
}
void DecisionForest::ComputeBoundingBox() {
assert(!has_bounding_box_);
for (auto& t : trees_)
t->ComputeBoundingBox();
has_bounding_box_ = true;
}
void DecisionForest::ComputeFeatureSplits() {
assert(!feature_splits_);
feature_splits_ =
std::make_unique<std::vector<std::vector<double>>>(max_feature_id_ + 1);
std::vector<std::set<double>> feature_splits_set(max_feature_id_ + 1);
for (const auto& t : trees_)
t->FillFeatureSplits(&feature_splits_set);
for (int i = 0; i <= max_feature_id_; ++i) {
(*feature_splits_)[i].push_back(-0.1);
(*feature_splits_)[i].insert((*feature_splits_)[i].end(),
feature_splits_set[i].begin(),
feature_splits_set[i].end());
(*feature_splits_)[i].push_back(1.1);
}
}
// Returns the accumulated score of every class for point `x`.
//
// The result has num_class_ entries, all starting at 0. Each tree adds its
// PredictLabel(x) to the entry indexed by the tree's class id.
//
// Preconditions (asserted): num_class_ >= 2, and every tree's class id lies in
// [0, num_class_).
std::vector<double> DecisionForest::ComputeScores(const Point& x) const {
  // Check the class count before it is used to size the vector, so a bad
  // value trips the assert instead of the allocation.
  assert(num_class_ >= 2);
  std::vector<double> scores(num_class_, 0);
  for (const auto& tree : trees_) {
    const auto class_id = tree->ClassId();
    // operator[] is unchecked; make an out-of-range class id fail loudly in
    // debug builds instead of writing out of bounds.
    assert(class_id >= 0 && class_id < num_class_);
    scores[class_id] += tree->PredictLabel(x);
  }
  // Return the local by name so copy elision can apply.
  return scores;
}
// Creates a LayeredBoundingBox for point `x`, restricted to a pair of classes.
//
// The box is constructed with this forest, num_class_, max_feature_id_ and the
// two class ids, its initial location is set to `x`, and then the bounding
// box at `x` of each selected tree is added to it. A tree is selected when its
// class id equals `class1` or `class2`; passing -1 for both selects every
// tree.
//
// Precondition (asserted): the forest's bounding boxes have been computed.
std::unique_ptr<LayeredBoundingBox> DecisionForest::GetLayeredBoundingBox(
    const Point& x,
    int class1,
    int class2) const {
  assert(has_bounding_box_);
  auto box = std::make_unique<LayeredBoundingBox>(
      this, num_class_, max_feature_id_, class1, class2);
  box->SetInitialLocation(x);

  // (-1, -1) disables class filtering; this does not depend on the tree, so
  // evaluate it once outside the loop.
  const bool include_all_trees = (class1 == -1 && class2 == -1);
  for (const auto& t : trees_) {
    if (include_all_trees || t->ClassId() == class1 ||
        t->ClassId() == class2) {
      box->AddBox(t->GetBoundingBox(x));
    }
  }
  // Return the local by name: std::move() here would only inhibit copy
  // elision.
  return box;
}
// Returns the joint bounding box of `x` across the whole forest: starting from
// a default-constructed BoundingBox, it is intersected with the bounding box
// that each tree reports for `x`.
BoundingBox DecisionForest::GetBoundingBox(const Point& x) const {
  BoundingBox joint_box;
  for (const auto& t : trees_) {
    joint_box.Intersect(*t->GetBoundingBox(x));
  }
  // Return the local by name so copy elision can apply; std::move() on a
  // returned local prevents it.
  return joint_box;
}
// Returns the per-feature split table built by ComputeFeatureSplits().
//
// Precondition (asserted): the table has been built, i.e. Setup() or
// ComputeFeatureSplits() has already run. Without it the pointer is null and
// dereferencing it would be undefined behaviour.
const std::vector<std::vector<double>>& DecisionForest::FeatureSplits() const {
  assert(feature_splits_);
  return *feature_splits_;
}
int DecisionForest::HammingDistanceBetween(const Point& p1, const Point& p2, int class1, int class2) const {
int dist = 0;
for (const auto& t : trees_) {
if (t->ClassId() != class1 && t->ClassId() != class2)
continue;
if (t->GetBoundingBox(p1) != t->GetBoundingBox(p2))
++dist;
}
return dist;
}
// Test helper: returns the score of class 1 for point `x`.
// Asserts that the forest has exactly two classes.
double DecisionForest::ComputeBinaryScoreForTesting(const Point& x) const {
  assert(num_class_ == 2);
  const std::vector<double> class_scores = ComputeScores(x);
  return class_scores[1];
}
// Test helper: returns how many trees the forest currently holds.
int DecisionForest::NumTreesForTesting() const {
  // The container size is converted to int, as the return type requires.
  return static_cast<int>(trees_.size());
}
// Takes ownership of `tree` and appends it to the forest's list of trees.
void DecisionForest::AddDecisionTree(std::unique_ptr<DecisionTree> tree) {
  trees_.emplace_back(std::move(tree));
}
} // namespace cz