Skip to content

Commit 3ac7f67

Browse files
committed
[feat] Use FastCexSolver in restricted cases that are are fairly easy to solve
1 parent c70df43 commit 3ac7f67

9 files changed

+172
-87
lines changed

include/klee/ADT/DisjointSetUnion.h

+21
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,27 @@ class DisjointSetUnion {
150150
}
151151
}
152152

153+
void getAllDependentSets(ValueType value,
154+
std::vector<ref<const SetType>> &result) const {
155+
ref<const SetType> compare = new SetType(value);
156+
for (auto &r : roots) {
157+
ref<const SetType> ics = disjointSets.at(r);
158+
if (SetType::intersects(ics, compare)) {
159+
result.push_back(ics);
160+
}
161+
}
162+
}
163+
void getAllIndependentSets(ValueType value,
164+
std::vector<ref<const SetType>> &result) const {
165+
ref<const SetType> compare = new SetType(value);
166+
for (auto &r : roots) {
167+
ref<const SetType> ics = disjointSets.at(r);
168+
if (!SetType::intersects(ics, compare)) {
169+
result.push_back(ics);
170+
}
171+
}
172+
}
173+
153174
DisjointSetUnion() {}
154175

155176
DisjointSetUnion(const internalStorage_ty &is) {

include/klee/Solver/IncompleteSolver.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,18 @@ class IncompleteSolver {
5858
/// StagedSolver - Adapter class for staging an incomplete solver with
5959
/// a complete secondary solver, to form an (optimized) complete
6060
/// solver.
61+
62+
typedef std::function<bool(const Query &)> QueryPredicate;
63+
6164
class StagedSolverImpl : public SolverImpl {
6265
private:
6366
std::unique_ptr<IncompleteSolver> primary;
6467
std::unique_ptr<Solver> secondary;
68+
QueryPredicate predicate;
6569

6670
public:
6771
StagedSolverImpl(std::unique_ptr<IncompleteSolver> primary,
68-
std::unique_ptr<Solver> secondary);
72+
std::unique_ptr<Solver> secondary, QueryPredicate predicate);
6973

7074
bool computeTruth(const Query &, bool &isValid);
7175
bool computeValidity(const Query &, PartialValidity &result);

lib/ADT/SparseStorage.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void SparseStorage<unsigned char>::print(llvm::raw_ostream &os,
3535
}
3636
os << "] default: ";
3737
}
38-
os << defaultValue;
38+
os << ((unsigned)defaultValue);
3939
}
4040

4141
template <>

lib/Expr/IndependentConstraintSetUnion.cpp

+2-16
Original file line numberDiff line numberDiff line change
@@ -95,27 +95,13 @@ void IndependentConstraintSetUnion::reEvaluateConcretization(
9595
void IndependentConstraintSetUnion::getAllIndependentConstraintSets(
9696
ref<Expr> e,
9797
std::vector<ref<const IndependentConstraintSet>> &result) const {
98-
ref<const IndependentConstraintSet> compare =
99-
new IndependentConstraintSet(new ExprEitherSymcrete::left(e));
100-
for (auto &r : roots) {
101-
ref<const IndependentConstraintSet> ics = disjointSets.at(r);
102-
if (!IndependentConstraintSet::intersects(ics, compare)) {
103-
result.push_back(ics);
104-
}
105-
}
98+
getAllIndependentSets(new ExprEitherSymcrete::left(e), result);
10699
}
107100

108101
void IndependentConstraintSetUnion::getAllDependentConstraintSets(
109102
ref<Expr> e,
110103
std::vector<ref<const IndependentConstraintSet>> &result) const {
111-
ref<const IndependentConstraintSet> compare =
112-
new IndependentConstraintSet(new ExprEitherSymcrete::left(e));
113-
for (auto &r : roots) {
114-
ref<const IndependentConstraintSet> ics = disjointSets.at(r);
115-
if (IndependentConstraintSet::intersects(ics, compare)) {
116-
result.push_back(ics);
117-
}
118-
}
104+
getAllDependentSets(new ExprEitherSymcrete::left(e), result);
119105
}
120106

121107
void IndependentConstraintSetUnion::addExpr(ref<Expr> e) {

lib/Solver/FastCexSolver.cpp

+70-14
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "klee/Solver/IncompleteSolver.h"
1919
#include "klee/Support/Debug.h"
2020
#include "klee/Support/ErrorHandling.h"
21+
#include "klee/Support/OptionCategories.h"
2122

2223
#include "klee/Support/CompilerWarning.h"
2324
DISABLE_WARNING_PUSH
@@ -33,6 +34,20 @@ DISABLE_WARNING_POP
3334
#include <vector>
3435

3536
using namespace klee;
37+
using namespace llvm;
38+
39+
namespace {
40+
enum class FastCexSolverType { EQUALITY, ALL };
41+
42+
cl::opt<FastCexSolverType> FastCexFor(
43+
"fast-cex-for",
44+
cl::desc(
45+
"Specifiy a query predicate to filter queries for FastCexSolver using"),
46+
cl::values(clEnumValN(FastCexSolverType::EQUALITY, "equality",
47+
"Query with only equality expressions"),
48+
clEnumValN(FastCexSolverType::ALL, "all", "All queries")),
49+
cl::init(FastCexSolverType::EQUALITY), cl::cat(SolvingCat));
50+
} // namespace
3651

3752
// Hacker's Delight, pgs 58-63
3853
static uint64_t minOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d) {
@@ -403,10 +418,11 @@ class CexPossibleEvaluator : public ExprEvaluator {
403418
ref<Expr> getInitialValue(const Array &array, unsigned index) {
404419
// If the index is out of range, we cannot assign it a value, since that
405420
// value cannot be part of the assignment.
406-
ref<ConstantExpr> constantArraySize = dyn_cast<ConstantExpr>(array.size);
421+
ref<ConstantExpr> constantArraySize =
422+
dyn_cast<ConstantExpr>(visit(array.size));
407423
if (!constantArraySize) {
408-
klee_error(
409-
"FIXME: Arrays of symbolic sizes are unsupported in FastCex\n");
424+
klee_error("FIXME: CexPossibleEvaluator: Arrays of symbolic sizes are "
425+
"unsupported in FastCex\n");
410426
std::abort();
411427
}
412428

@@ -433,11 +449,11 @@ class CexExactEvaluator : public ExprEvaluator {
433449
ref<Expr> getInitialValue(const Array &array, unsigned index) {
434450
// If the index is out of range, we cannot assign it a value, since that
435451
// value cannot be part of the assignment.
436-
ref<ConstantExpr> constantArraySize = dyn_cast<ConstantExpr>(array.size);
452+
ref<ConstantExpr> constantArraySize =
453+
dyn_cast<ConstantExpr>(visit(array.size));
437454
if (!constantArraySize) {
438-
klee_error(
439-
"FIXME: Arrays of symbolic sizes are unsupported in FastCex\n");
440-
std::abort();
455+
return ReadExpr::create(UpdateList(&array, 0),
456+
ConstantExpr::alloc(index, array.getDomain()));
441457
}
442458

443459
if (index >= constantArraySize->getZExtValue()) {
@@ -485,10 +501,11 @@ class CexData {
485501
CexObjectData &getObjectData(const Array *A) {
486502
CexObjectData *&Entry = objects[A];
487503

488-
ref<ConstantExpr> constantArraySize = dyn_cast<ConstantExpr>(A->size);
504+
ref<ConstantExpr> constantArraySize =
505+
dyn_cast<ConstantExpr>(evaluatePossible(A->size));
489506
if (!constantArraySize) {
490-
klee_error(
491-
"FIXME: Arrays of symbolic sizes are unsupported in FastCex\n");
507+
klee_error("FIXME: CexData: Arrays of symbolic sizes are unsupported in "
508+
"FastCex\n");
492509
std::abort();
493510
}
494511

@@ -529,7 +546,7 @@ class CexData {
529546
// to see if this is an initial read or not.
530547
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(re->index)) {
531548
if (ref<ConstantExpr> constantArraySize =
532-
dyn_cast<ConstantExpr>(array->size)) {
549+
dyn_cast<ConstantExpr>(evaluatePossible(array->size))) {
533550
uint64_t index = CE->getZExtValue();
534551

535552
if (index < constantArraySize->getZExtValue()) {
@@ -1171,6 +1188,7 @@ bool FastCexSolver::computeInitialValues(
11711188
const Query &query, const std::vector<const Array *> &objects,
11721189
std::vector<SparseStorage<unsigned char>> &values, bool &hasSolution) {
11731190
CexData cd;
1191+
query.dump();
11741192

11751193
bool isValid;
11761194
bool success = propagateValues(query, cd, true, isValid);
@@ -1187,7 +1205,7 @@ bool FastCexSolver::computeInitialValues(
11871205
for (unsigned i = 0; i != objects.size(); ++i) {
11881206
const Array *array = objects[i];
11891207
assert(array);
1190-
SparseStorage<unsigned char> data;
1208+
SparseStorage<unsigned char> data(0);
11911209
ref<ConstantExpr> arrayConstantSize =
11921210
dyn_cast<ConstantExpr>(cd.evaluatePossible(array->size));
11931211
assert(arrayConstantSize &&
@@ -1212,7 +1230,45 @@ bool FastCexSolver::computeInitialValues(
12121230
return true;
12131231
}
12141232

1233+
class OnlyEqualityWithConstantQueryPredicate {
1234+
public:
1235+
explicit OnlyEqualityWithConstantQueryPredicate() {}
1236+
1237+
bool operator()(const Query &query) const {
1238+
for (auto constraint : query.constraints.cs()) {
1239+
if (const EqExpr *ee = dyn_cast<EqExpr>(constraint)) {
1240+
if (!isa<ConstantExpr>(ee->left)) {
1241+
return false;
1242+
}
1243+
} else {
1244+
return false;
1245+
}
1246+
}
1247+
if (ref<EqExpr> ee = dyn_cast<EqExpr>(query.negateExpr().expr)) {
1248+
if (!isa<ConstantExpr>(ee->left)) {
1249+
return false;
1250+
}
1251+
} else {
1252+
return false;
1253+
}
1254+
return true;
1255+
}
1256+
};
1257+
1258+
class TrueQueryPredicate {
1259+
public:
1260+
explicit TrueQueryPredicate() {}
1261+
1262+
bool operator()(const Query &query) const { return true; }
1263+
};
1264+
12151265
std::unique_ptr<Solver> klee::createFastCexSolver(std::unique_ptr<Solver> s) {
1216-
return std::make_unique<Solver>(std::make_unique<StagedSolverImpl>(
1217-
std::make_unique<FastCexSolver>(), std::move(s)));
1266+
if (FastCexFor == FastCexSolverType::EQUALITY) {
1267+
return std::make_unique<Solver>(std::make_unique<StagedSolverImpl>(
1268+
std::make_unique<FastCexSolver>(), std::move(s),
1269+
OnlyEqualityWithConstantQueryPredicate()));
1270+
} else {
1271+
return std::make_unique<Solver>(std::make_unique<StagedSolverImpl>(
1272+
std::make_unique<FastCexSolver>(), std::move(s), TrueQueryPredicate()));
1273+
}
12181274
}

lib/Solver/IncompleteSolver.cpp

+68-49
Original file line numberDiff line numberDiff line change
@@ -49,60 +49,68 @@ PartialValidity IncompleteSolver::computeValidity(const Query &query) {
4949
/***/
5050

5151
StagedSolverImpl::StagedSolverImpl(std::unique_ptr<IncompleteSolver> primary,
52-
std::unique_ptr<Solver> secondary)
53-
: primary(std::move(primary)), secondary(std::move(secondary)) {}
52+
std::unique_ptr<Solver> secondary,
53+
QueryPredicate predicate)
54+
: primary(std::move(primary)), secondary(std::move(secondary)),
55+
predicate(predicate) {}
5456

5557
bool StagedSolverImpl::computeTruth(const Query &query, bool &isValid) {
56-
PartialValidity trueResult = primary->computeTruth(query);
58+
if (predicate(query)) {
59+
PartialValidity trueResult = primary->computeTruth(query);
5760

58-
if (trueResult != PValidity::None) {
59-
isValid = (trueResult == PValidity::MustBeTrue);
60-
return true;
61+
if (trueResult != PValidity::None) {
62+
isValid = (trueResult == PValidity::MustBeTrue);
63+
return true;
64+
}
6165
}
6266

6367
return secondary->impl->computeTruth(query, isValid);
6468
}
6569

6670
bool StagedSolverImpl::computeValidity(const Query &query,
6771
PartialValidity &result) {
68-
bool tmp;
69-
70-
switch (primary->computeValidity(query)) {
71-
case PValidity::MustBeTrue:
72-
result = PValidity::MustBeTrue;
73-
break;
74-
case PValidity::MustBeFalse:
75-
result = PValidity::MustBeFalse;
76-
break;
77-
case PValidity::TrueOrFalse:
78-
result = PValidity::TrueOrFalse;
79-
break;
80-
case PValidity::MayBeTrue:
81-
if (secondary->impl->computeTruth(query, tmp)) {
82-
83-
result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse;
84-
} else {
85-
result = PValidity::MayBeTrue;
86-
}
87-
break;
88-
case PValidity::MayBeFalse:
89-
if (secondary->impl->computeTruth(query.negateExpr(), tmp)) {
90-
result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse;
91-
} else {
92-
result = PValidity::MayBeFalse;
72+
if (predicate(query)) {
73+
bool tmp;
74+
75+
switch (primary->computeValidity(query)) {
76+
case PValidity::MustBeTrue:
77+
result = PValidity::MustBeTrue;
78+
break;
79+
case PValidity::MustBeFalse:
80+
result = PValidity::MustBeFalse;
81+
break;
82+
case PValidity::TrueOrFalse:
83+
result = PValidity::TrueOrFalse;
84+
break;
85+
case PValidity::MayBeTrue:
86+
if (secondary->impl->computeTruth(query, tmp)) {
87+
88+
result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse;
89+
} else {
90+
result = PValidity::MayBeTrue;
91+
}
92+
break;
93+
case PValidity::MayBeFalse:
94+
if (secondary->impl->computeTruth(query.negateExpr(), tmp)) {
95+
result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse;
96+
} else {
97+
result = PValidity::MayBeFalse;
98+
}
99+
break;
100+
default:
101+
if (!secondary->impl->computeValidity(query, result))
102+
return false;
103+
break;
93104
}
94-
break;
95-
default:
96-
if (!secondary->impl->computeValidity(query, result))
97-
return false;
98-
break;
105+
} else {
106+
return secondary->impl->computeValidity(query, result);
99107
}
100108

101109
return true;
102110
}
103111

104112
bool StagedSolverImpl::computeValue(const Query &query, ref<Expr> &result) {
105-
if (primary->computeValue(query, result))
113+
if (predicate(query) && primary->computeValue(query, result))
106114
return true;
107115

108116
return secondary->impl->computeValue(query, result);
@@ -111,25 +119,28 @@ bool StagedSolverImpl::computeValue(const Query &query, ref<Expr> &result) {
111119
bool StagedSolverImpl::computeInitialValues(
112120
const Query &query, const std::vector<const Array *> &objects,
113121
std::vector<SparseStorage<unsigned char>> &values, bool &hasSolution) {
114-
if (primary->computeInitialValues(query, objects, values, hasSolution))
122+
if (predicate(query) &&
123+
primary->computeInitialValues(query, objects, values, hasSolution))
115124
return true;
116125

117126
return secondary->impl->computeInitialValues(query, objects, values,
118127
hasSolution);
119128
}
120129

121130
bool StagedSolverImpl::check(const Query &query, ref<SolverResponse> &result) {
122-
std::vector<const Array *> objects;
123-
findSymbolicObjects(query, objects);
124-
std::vector<SparseStorage<unsigned char>> values;
125-
126-
bool hasSolution;
127-
128-
bool primaryResult =
129-
primary->computeInitialValues(query, objects, values, hasSolution);
130-
if (primaryResult && hasSolution) {
131-
result = new InvalidResponse(objects, values);
132-
return true;
131+
if (predicate(query)) {
132+
std::vector<const Array *> objects;
133+
findSymbolicObjects(query, objects);
134+
std::vector<SparseStorage<unsigned char>> values;
135+
136+
bool hasSolution;
137+
138+
bool primaryResult =
139+
primary->computeInitialValues(query, objects, values, hasSolution);
140+
if (primaryResult && hasSolution) {
141+
result = new InvalidResponse(objects, values);
142+
return true;
143+
}
133144
}
134145

135146
return secondary->impl->check(query, result);
@@ -138,6 +149,14 @@ bool StagedSolverImpl::check(const Query &query, ref<SolverResponse> &result) {
138149
bool StagedSolverImpl::computeValidityCore(const Query &query,
139150
ValidityCore &validityCore,
140151
bool &isValid) {
152+
if (predicate(query)) {
153+
PartialValidity trueResult = primary->computeTruth(query);
154+
155+
if (trueResult == PValidity::MayBeFalse) {
156+
isValid = false;
157+
return true;
158+
}
159+
}
141160
return secondary->impl->computeValidityCore(query, validityCore, isValid);
142161
}
143162

0 commit comments

Comments
 (0)