forked from chong-z/tree-ensemble-attack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bounding_box.h
170 lines (120 loc) · 4.64 KB
/
bounding_box.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include <boost/container/flat_map.hpp>

#include "interval.h"
#include "utility.h"
namespace cz {
// Forward declarations: in this header these types are only referred to
// through pointers (raw pointers or std::unique_ptr), so their full
// definitions are not needed here.
class DecisionForest;
class DecisionTree;
class OrderedBoxes;
// An axis-aligned box in feature space, stored sparsely as a map from
// feature_id to Interval (|intervals_|), together with the tree that owns it
// (|owner_tree_|) and the prediction label attached to it (|label_|).
// NOTE(review): features without an entry in |intervals_| are presumably
// unbounded -- confirm against the definitions of HasUpper()/HasLower().
class BoundingBox {
 public:
  // Sorted, vector-backed map keyed by feature_id.
  using IntervalsType = boost::container::flat_map<int, Interval>;

  BoundingBox();
  // NOTE(review): presumably records the argument as OwnerTree(); confirm in
  // the out-of-line definition.
  explicit BoundingBox(const DecisionTree*);
  // Declared here and defined out of line. It is not noexcept; if noexcept is
  // ever added, it must be added to this declaration and to the definition
  // together.
  BoundingBox(BoundingBox&&);
  BoundingBox(const BoundingBox&) = default;
  // A user-declared move constructor makes the implicitly-declared copy
  // assignment operator deleted. Default it explicitly so that the type is
  // copy-assignable as well as copy-constructible (e.g. usable with
  // std::vector insert/erase or std::sort). No move assignment operator is
  // declared, so assignment from an rvalue also resolves to this operator.
  BoundingBox& operator=(const BoundingBox&) = default;
  ~BoundingBox();

  // NOTE(review): presumably narrows the interval of |feature_id| to its
  // intersection with the given interval; confirm.
  void IntersectFeature(int feature_id, const Interval&);
  void Intersect(const BoundingBox&);
  // NOTE(review): not const-qualified, so it cannot be called on a const
  // BoundingBox; making it const requires changing the definition as well.
  bool Overlaps(const BoundingBox&);

  // Per-feature bound queries.
  bool HasUpper(int feature_id) const;
  double Upper(int feature_id) const;
  bool HasLower(int feature_id) const;
  double Lower(int feature_id) const;

  // Point queries.
  bool Contains(const Point&) const;
  Patch ClosestPatchTo(const Point& point) const;
  // |norm_type| selects the norm; the accepted values are defined by the
  // out-of-line implementation.
  double NormTo(const Point& point, int norm_type) const;

  // Direct access to the underlying feature_id -> Interval map.
  IntervalsType& Intervals();
  const IntervalsType& Intervals() const;
  // NOTE(review): the non-const overload returns a mutable reference and may
  // insert a missing feature (flat_map semantics); confirm in the definition.
  Interval& operator[](int feature_id);
  const Interval& operator[](int feature_id) const;
  // Returns the interval by value; presumably an empty Interval when
  // |feature_id| has no entry -- confirm.
  Interval GetOrEmpty(int feature_id) const;
  bool operator==(const BoundingBox&) const;

  const DecisionTree* OwnerTree() const;
  // Prediction label for this box (|label_|).
  double Label() const;
  void SetLabel(double label);

  std::string ToDebugString() const;
  void Clear();

 private:
  // feature_id -> Interval.
  IntervalsType intervals_;
  const DecisionTree* owner_tree_ = nullptr;
  // Prediction label for this box.
  double label_ = 0;
};
// A collection of BoundingBoxes holding at most one box per DecisionTree
// (|boxes_|), together with a current point (|location_|) and cached
// per-class scores (|scores_|) for the owning forest (|owner_forest_|).
class LayeredBoundingBox {
 public:
  LayeredBoundingBox(const DecisionForest* owner_forest,
                     int num_class,
                     int max_feature_id,
                     int class1,
                     int class2);
  // Declared here and defined out of line: |ordered_boxes_| is a
  // std::unique_ptr to the forward-declared (incomplete) OrderedBoxes, and
  // its deleter needs the complete type, so the destructor must not be
  // implicitly defined in this header.
  ~LayeredBoundingBox();

  // Non-copyable. Declaring the copy operations also suppresses the implicit
  // move operations, so the type is neither copyable nor movable.
  LayeredBoundingBox(const LayeredBoundingBox&) = delete;
  LayeredBoundingBox& operator=(const LayeredBoundingBox&) = delete;

  void AddBox(const BoundingBox*);

  // When |scores| is null the functions presumably fall back to Scores();
  // NOTE(review): confirm in the definitions.
  int PredictionLabel(const std::vector<double>* scores = nullptr) const;
  double LabelScore(int victim_label,
                    const std::vector<double>* scores = nullptr) const;
  // Scores per class (|scores_|).
  const std::vector<double>& Scores() const;

  std::vector<const BoundingBox*> GetEffectiveBoxesForFeature(
      int feature_id,
      SearchMode search_mode) const;
  std::vector<FeatureDir> GetBoundedFeatures() const;
  std::vector<const BoundingBox*> GetAlternativeBoxes(
      const BoundingBox& target_feature_constrain,
      int max_dist,
      const BoundingBox* box_to_replace,
      bool enable_relaxed_boundary,
      const BoundingBox* hard_constrain = nullptr) const;
  // Appends to |incompatible_boxes| (out-parameter).
  void FillIncompatibleBoxes(
      int feature_id,
      double value,
      std::vector<const BoundingBox*>* incompatible_boxes) const;
  // Current |location_| must also be the optimized point.
  Patch StretchWithinBox(
      const Patch& patch,
      const Point& victim_point,
      const BoundingBox* constrain_box,
      const std::vector<const BoundingBox*>& incompatible_boxes) const;
  std::vector<const BoundingBox*> GetNewBoxes(
      const Patch& patch,
      const std::vector<const BoundingBox*>& incompatible_boxes) const;
  std::vector<double> GetNewScores(
      const std::vector<const BoundingBox*>& incompatible_boxes,
      const std::vector<const BoundingBox*>& new_boxes) const;
  void TightenPoint(Point* new_adv,
                    const std::vector<const BoundingBox*>& new_boxes) const;

  // Moving the current location.
  void ShiftPoint(const Point& point);
  // |ShiftByPatch| is more efficient.
  void ShiftByPatch(const Patch& patch);
  void ShiftByDirection(const Direction& dir);

  const BoundingBox* GetCachedIntersection() const;
  // The name is misspelled ("Indenpendent"); it is kept unchanged so that it
  // keeps matching the out-of-line definition and existing callers.
  std::vector<BoundingBox> GetIndenpendentBoundingBoxes() const;
  void SetInitialLocation(const Point& initial_location);
  const Point& Location() const;
  void VerifyCachedIntersectionForTesting() const;
  // NOTE(review): presumably returns |hash_|, maintained from the box
  // pointers via |ptr_hasher_|; confirm in the definition.
  size_t Hash() const;
  const BoundingBox* GetBoxForTree(const DecisionTree*) const;
  std::vector<const BoundingBox*> GetBoxForAllTree() const;
  bool CheckScoresForTesting(const Patch& patch,
                             const std::vector<double>& scores) const;
  void AssertTightForTesting(const Point& victim_point) const;

 private:
  // Won't update |ordered_boxes_|.
  void RemoveBox(const BoundingBox* box);

  Point location_;
  int class1_ = -1;
  int class2_ = -1;
  std::hash<const BoundingBox*> ptr_hasher_;
  size_t hash_ = 0;
  std::unique_ptr<OrderedBoxes> ordered_boxes_;
  // One box per tree.
  std::map<const DecisionTree*, const BoundingBox*> boxes_;
  // Scores per class.
  std::vector<double> scores_;
  const DecisionForest* owner_forest_ = nullptr;
};
} // namespace cz