Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dramatically reduce encoding times for some equality formulas #121

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 138 additions & 16 deletions cozy/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ def __init__(self, z3ctx, z3solver, stop_callback : Callable[[], bool]):
self.int_one = z3.IntVal(1, self.ctx)
self.true = z3.BoolVal(True, self.ctx)
self.false = z3.BoolVal(False, self.ctx)

# tuple_sorts maps TTuple types to (z3_sort, constructor, projections)
# produced by z3.TupleSort. Using this cache ensures that we never
# construct more than one fresh Z3 tuple sort for any possible tuple
# type. (Z3 tuple sorts are identified by string names, not by the
# types that go into the tuple!) See self.symbolic_sort(...).
self.tuple_sorts = {}

self.stop_callback = stop_callback
assert to_bool(self.true) is True
assert to_bool(self.false) is False
Expand Down Expand Up @@ -355,27 +363,45 @@ def eq(self, t, e1, e2, deep=False):
return res[0][0]

elif isinstance(t, TBag) or isinstance(t, TSet):
elem_type = t.elem_type
lhs_mask, lhs_elems = e1
rhs_mask, rhs_elems = e2

# n = max(len(lhs_elems), len(rhs_elems))
try:
s1 = self.symbolic_form(t, e1)
s2 = self.symbolic_form(t, e2)
return s1 == s2
except NotImplementedError as e:
# If we can't use the efficient symbolic_form equality
# encoding, then fall back to very slow old-style equality
# encoding.

# lengths equal... might not be necessary
e1len = self.len_of(e1)
e2len = self.len_of(e2)
conds = []
conds.append(e1len == e2len)
# NOTE 2020/8/8: There are a few types for which symbolic_form
# is not implemented (e.g. maps, lists). We can get rid of
# this branch someday when we've figured out how to construct
# a fully-symbolic encoding for every type.

lhs_counts = [ (x, self.count_in(elem_type, e1, x, deep=deep)) for x in lhs_elems ]
for x, count in lhs_counts:
conds.append(count == self.count_in(elem_type, e2, x, deep=deep))
print("WARNING: falling back to slow bag/set equality encoding (exception={!r})".format(e))

rhs_counts = [ (x, self.count_in(elem_type, e1, x, deep=deep)) for x in rhs_elems ]
for x, count in rhs_counts:
conds.append(count == self.count_in(elem_type, e1, x, deep=deep))
elem_type = t.elem_type
lhs_mask, lhs_elems = e1
rhs_mask, rhs_elems = e2

# n = max(len(lhs_elems), len(rhs_elems))

# lengths equal... might not be necessary
e1len = self.len_of(e1)
e2len = self.len_of(e2)
conds = []
conds.append(e1len == e2len)

lhs_counts = [ (x, self.count_in(elem_type, e1, x, deep=deep)) for x in lhs_elems ]
for x, count in lhs_counts:
conds.append(count == self.count_in(elem_type, e2, x, deep=deep))

rhs_counts = [ (x, self.count_in(elem_type, e1, x, deep=deep)) for x in rhs_elems ]
for x, count in rhs_counts:
conds.append(count == self.count_in(elem_type, e1, x, deep=deep))

return self.all(*conds)

return self.all(*conds)
elif isinstance(t, TMap):
conds = [self.eq(t.v, e1["default"], e2["default"], deep=deep)]
def map_keys(m):
Expand Down Expand Up @@ -409,6 +435,102 @@ def map_keys(m):
return self.all(*conds)
else:
raise NotImplementedError(t)
def symbolic_sort(self, t):
"""
t - a type

returns the Z3 sort for the encoding produced by `self.symbolic_form`
for objects of the given type.

Also, this function initializes self.tuple_sorts[*] for any TTuple type
contained in `t`.
"""
if isinstance(t, TInt):
return z3.IntSort(self.ctx)
elif isinstance(t, TLong):
return z3.IntSort(self.ctx)
elif isinstance(t, TBool):
return z3.BoolSort(self.ctx)
elif isinstance(t, TString):
return z3.IntSort(self.ctx)
elif isinstance(t, THandle):
return z3.IntSort(self.ctx)
elif isinstance(t, TNative):
return z3.IntSort(self.ctx)
elif isinstance(t, TBag):
return z3.ArraySort(
self.symbolic_sort(t.elem_type),
self.symbolic_sort(TInt()))
elif isinstance(t, TSet):
return z3.ArraySort(
self.symbolic_sort(t.elem_type),
self.symbolic_sort(TBool()))
elif isinstance(t, TTuple):
res = self.tuple_sorts.get(t)
if res is None:
res = self.tuple_sorts[t] = z3.TupleSort(fresh_name("CustomTupleSort"), [self.symbolic_sort(t) for t in t.ts], self.ctx)
sort, constructor, projections = res
return sort
elif isinstance(t, TRecord):
return self.symbolic_sort(TTuple(tuple(ft for field, ft in t.fields)))
else:
raise NotImplementedError(t)
def symbolic_form(self, t, x):
"""
t - type of x
x - a symbolic value

returns a Z3 AST representing the object

NOTE: the returned AST is not suitable for "deep" equality comparisons.
The docstring on `compare_values` in value_types.py describes deep
equality in detail.

NOTE: the implementation of this function is Z3-specific, and is not
easily portable to other solvers. It relies on Z3's "constant array"
primitive, which is not part of the SMT-LIB standard.
"""
if decideable(t):
return x
elif isinstance(t, THandle):
return x[0] # address
elif isinstance(t, TNative):
return x
elif isinstance(t, TBag):
elem_type = t.elem_type
symbolic_elem_type = self.symbolic_sort(elem_type)
result = z3.K(symbolic_elem_type, self.int_zero)
masks, elems = x
for mask, elem in zip(masks, elems):
symbolic_elem = self.symbolic_form(elem_type, elem)
result = z3.Store(
result,
symbolic_elem,
z3.Select(result, symbolic_elem) + z3.If(mask, self.int_one, self.int_zero))
return result
elif isinstance(t, TSet):
elem_type = t.elem_type
symbolic_elem_type = self.symbolic_sort(elem_type)
result = z3.K(symbolic_elem_type, self.false)
masks, elems = x
for mask, elem in zip(masks, elems):
symbolic_elem = self.symbolic_form(elem_type, elem)
result = z3.Store(
result,
symbolic_elem,
self.any(z3.Select(result, symbolic_elem), mask))
return result
elif isinstance(t, TTuple):
# Calling symbolic_sort initializes self.tuple_sorts[t].
self.symbolic_sort(t)
sort, constructor, projections = self.tuple_sorts[t]
return constructor(*[self.symbolic_form(tt, elem) for tt, elem in zip(t.ts, x)])
elif isinstance(t, TRecord):
return self.symbolic_form(
TTuple(tuple(ft for field, ft in t.fields)),
tuple(x[field] for field, ft in t.fields))
else:
raise NotImplementedError(t)
def count_in(self, t, bag, x, deep=False):
"""
t - type of elems in bag
Expand Down
9 changes: 9 additions & 0 deletions tests/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,12 @@ def test_regression34(self):
vars=OrderedSet([EVar('issues').with_type(TBag(THandle('Issue', TRecord((('id', TInt()), ('author_id', TInt()), ('project', THandle('Project', TRecord((('id', TInt()), ('status', TInt()), ('modules', TBag(THandle('ProjectModule', TRecord((('id', TInt()), ('name', TString())))))))))), ('statuses', TBag(THandle('IssueStatus', TRecord((('id', TInt()), ('is_closed', TBool())))))), ('assigned_to', TInt())))))), EVar('p1').with_type(TInt())]),
collection_depth=4,
validate_model=True)

def test_flatmap_equality_encoding_speed(self):
e = EBinOp(EFlatMap(EVar('Rs').with_type(TBag(TRecord((('A', TInt()), ('B', TString()))))), ELambda(EVar('r').with_type(TRecord((('A', TInt()), ('B', TString())))), ELet(EGetField(EVar('r').with_type(TRecord((('A', TInt()), ('B', TString())))), 'A').with_type(TInt()), ELambda(EVar('_tmp1536').with_type(TInt()), EFlatMap(EVar('Ss').with_type(TBag(TRecord((('B', TString()), ('C', TInt()))))), ELambda(EVar('s').with_type(TRecord((('B', TString()), ('C', TInt())))), EFlatMap(EVar('Qs').with_type(TBag(TRecord((('B', TString()), ('C', TInt()))))), ELambda(EVar('q').with_type(TRecord((('B', TString()), ('C', TInt())))), ELet(EGetField(EVar('q').with_type(TRecord((('B', TString()), ('C', TInt())))), 'B').with_type(TString()), ELambda(EVar('_tmp1533').with_type(TString()), EMap(EFilter(EVar('Ws').with_type(TBag(TRecord((('B', TString()), ('C', TInt()))))), ELambda(EVar('w').with_type(TRecord((('B', TString()), ('C', TInt())))), EBinOp(EBinOp(EVar('_tmp1533').with_type(TString()), '==', EGetField(EVar('w').with_type(TRecord((('B', TString()), ('C', TInt())))), 'B').with_type(TString())).with_type(TBool()), 'and', EBinOp(EVar('_tmp1536').with_type(TInt()), '==', ENum(15).with_type(TInt())).with_type(TBool())).with_type(TBool()))).with_type(TBag(TRecord((('B', TString()), ('C', TInt()))))), ELambda(EVar('w').with_type(TRecord((('B', TString()), ('C', TInt())))), ETuple((EVar('_tmp1536').with_type(TInt()), EGetField(EVar('s').with_type(TRecord((('B', TString()), ('C', TInt())))), 'C').with_type(TInt()), EVar('_tmp1533').with_type(TString()), EGetField(EVar('w').with_type(TRecord((('B', TString()), ('C', TInt())))), 'C').with_type(TInt()))).with_type(TTuple((TInt(), TInt(), TString(), TInt()))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))), '==', EFlatMap(EVar('Rs').with_type(TBag(TRecord((('A', TInt()), ('B', TString()))))), ELambda(EVar('r').with_type(TRecord((('A', TInt()), ('B', TString())))), EFlatMap(EVar('Ss').with_type(TBag(TRecord((('B', TString()), ('C', TInt()))))), ELambda(EVar('s').with_type(TRecord((('B', TString()), ('C', TInt())))), EFlatMap(EVar('Qs').with_type(TBag(TRecord((('B', TString()), ('C', TInt()))))), ELambda(EVar('q').with_type(TRecord((('B', TString()), ('C', TInt())))), EMap(EVar('Ws').with_type(TBag(TRecord((('B', TString()), ('C', TInt()))))), ELambda(EVar('w').with_type(TRecord((('B', TString()), ('C', TInt())))), ETuple((EGetField(EVar('r').with_type(TRecord((('A', TInt()), ('B', TString())))), 'A').with_type(TInt()), EGetField(EVar('s').with_type(TRecord((('B', TString()), ('C', TInt())))), 'C').with_type(TInt()), EGetField(EVar('q').with_type(TRecord((('B', TString()), ('C', TInt())))), 'B').with_type(TString()), EGetField(EVar('w').with_type(TRecord((('B', TString()), ('C', TInt())))), 'C').with_type(TInt()))).with_type(TTuple((TInt(), TInt(), TString(), TInt()))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt())))))).with_type(TBag(TTuple((TInt(), TInt(), TString(), TInt()))))).with_type(TBool())
s = IncrementalSolver(validate_model=True)

# NOTE: Z3 is much faster at solving the negation of e. The main
# purpose of this test is to check encoding speed.
s.add_assumption(ENot(e))
assert s.satisfiable(ETRUE)