Skip to content

Commit

Permalink
fix EHeapElems bug by combining length member with heap array
Browse files Browse the repository at this point in the history
  • Loading branch information
izgzhen committed Oct 26, 2018
1 parent cb7fd1e commit a95628b
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 49 deletions.
31 changes: 21 additions & 10 deletions cozy/codegen/cxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def visit_EListGet(self, e):
evaluation.construct_value(e.type)).with_type(e.type))).with_type(e.type))

def visit_EArrayIndexOf(self, e):
assert isinstance(e.a, EVar) # TODO: make this fast when this is false
# assert isinstance(e.a, EVar) # TODO: make this fast when this is false
it = self.fv(TNative("{}::const_iterator".format(self.visit(e.a.type, "").strip())), "cursor")
res = self.fv(INT, "index")
self.visit(seq([
Expand All @@ -374,6 +374,9 @@ def visit_EMakeMap2(self, e):
self.declare(m, e)
return m.id

def visit_EArrayList(self, e):
return "(std::vector<{}>())".format(self.visit(e.type.elem_type, ""))

def visit_EHasKey(self, e):
map = self.visit(e.map)
key = self.visit(e.key)
Expand Down Expand Up @@ -980,6 +983,12 @@ def setup_types(self, spec, state_exps, sharing):
if t not in self.types and type(t) in [THandle, TRecord, TTuple, TEnum]:
name = names.get(t, self.fn("Type"))
self.types[t] = name
# add representation type for extension data structures
h = extension_handler(type(t))
if t not in self.types and h is not None:
name = names.get(t, self.fn("Type"))
t = h.rep_type(t)
self.types[t] = name

def visit_args(self, args):
for (i, (v, t)) in enumerate(args):
Expand Down Expand Up @@ -1051,21 +1060,23 @@ def visit_Spec(self, spec : Spec, state_exps : { str : Exp }, sharing, abstract_
if isinstance(t, THandle):
# No overridden hash code! We use pointers instead.
continue
x = EVar("x").with_type(t)
if isinstance(t, TEnum):
fields = [EEnumToInt(x).with_type(INT)]
elif isinstance(t, TRecord):
fields = [EGetField(x, f).with_type(ft) for (f, ft) in t.fields]
elif isinstance(t, TTuple):
fields = [ETupleGet(x, n).with_type(tt) for (n, tt) in enumerate(t.ts)]
else:
raise NotImplementedError(t)
if not all([is_scalar(f.type) for f in fields]):
continue
self.write("struct _Hash", name, " ")
with self.block():
self.write_stmt("typedef ", spec.name, "::", name, " argument_type;")
self.write_stmt("typedef std::size_t result_type;")
self.begin_statement()
self.write("result_type operator()(const argument_type& x) const noexcept ")
x = EVar("x").with_type(t)
if isinstance(t, TEnum):
fields = [EEnumToInt(x).with_type(INT)]
elif isinstance(t, TRecord):
fields = [EGetField(x, f).with_type(ft) for (f, ft) in t.fields]
elif isinstance(t, TTuple):
fields = [ETupleGet(x, n).with_type(tt) for (n, tt) in enumerate(t.ts)]
else:
raise NotImplementedError(t)
with self.block():
self.visit(self.compute_hash(fields))
self.end_statement()
Expand Down
3 changes: 3 additions & 0 deletions cozy/codegen/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def body(x):
self.end_statement()
self.end_statement()

def visit_EArrayList(self, e):
return "(new {}[0])".format(self.visit(e.type.elem_type, ""))

def initialize_native_list(self, out):
init = "new {};\n".format(self.visit(out.type, name="()"))
return SEscape("{indent}{e} = " + init, ["e"], [out])
Expand Down
1 change: 1 addition & 0 deletions cozy/structures/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from cozy.syntax import Type, Exp, Stm

TArray = declare_case(Type, "TArray", ["elem_type"])
EArrayList = declare_case(Exp, "EArrayList", [])
EArrayCapacity = declare_case(Exp, "EArrayCapacity", ["e"])
EArrayLen = declare_case(Exp, "EArrayLen", ["e"])
EArrayGet = declare_case(Exp, "EArrayGet", ["a", "i"])
Expand Down
60 changes: 28 additions & 32 deletions cozy/structures/heaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from cozy.syntax_tools import fresh_var, pprint, mk_lambda, alpha_equivalent
from cozy.pools import Pool, RUNTIME_POOL, STATE_POOL

from .arrays import TArray, EArrayGet, EArrayIndexOf, SArrayAlloc, SEnsureCapacity, EArrayLen
from .arrays import TArray, EArrayGet, EArrayIndexOf, SArrayAlloc, SEnsureCapacity, EArrayLen, EArrayList

TMinHeap = declare_case(Type, "TMinHeap", ["elem_type", "key_type"])
TMaxHeap = declare_case(Type, "TMaxHeap", ["elem_type", "key_type"])
Expand All @@ -14,8 +14,8 @@
EMakeMaxHeap = declare_case(Exp, "EMakeMaxHeap", ["e", "key_function"])

EHeapElems = declare_case(Exp, "EHeapElems", ["e"]) # all elements
EHeapPeek = declare_case(Exp, "EHeapPeek", ["e", "heap_length"]) # look at min
EHeapPeek2 = declare_case(Exp, "EHeapPeek2", ["e", "heap_length"]) # look at 2nd min
EHeapPeek = declare_case(Exp, "EHeapPeek", ["e"]) # look at min
EHeapPeek2 = declare_case(Exp, "EHeapPeek2", ["e"]) # look at 2nd min

def to_heap(e : Exp) -> Exp:
"""Construct a heap that would be useful for evaluating `e`.
Expand Down Expand Up @@ -90,10 +90,9 @@ def enumerate(self, context, size, pool, enumerate_subexps, enumerate_lambdas):
t = e.type
if isinstance(t, TMinHeap) or isinstance(t, TMaxHeap):
elem_type = t.elem_type
n_elems = EStateVar(ELen(EHeapElems(e).with_type(TBag(elem_type)))).with_type(INT)
# yielding EHeapElems would be redundant
yield EHeapPeek (EStateVar(e).with_type(e.type), n_elems).with_type(elem_type)
yield EHeapPeek2(EStateVar(e).with_type(e.type), n_elems).with_type(elem_type)
yield EHeapPeek (EStateVar(e).with_type(e.type)).with_type(elem_type)
yield EHeapPeek2(EStateVar(e).with_type(e.type)).with_type(elem_type)

def default_value(self, t : Type, default_value) -> Exp:
"""Construct a default value of the given type.
Expand All @@ -108,10 +107,6 @@ def default_value(self, t : Type, default_value) -> Exp:

def check_wf(self, e : Exp, is_valid, state_vars : {EVar}, args : {EVar}, pool = RUNTIME_POOL, assumptions : Exp = ETRUE):
"""Return None or a string indicating a well-formedness error."""
if (isinstance(e, EHeapPeek) or isinstance(e, EHeapPeek2)):
heap = e.e
if not is_valid(EEq(e.heap_length, ELen(EHeapElems(heap).with_type(TBag(heap.type.elem_type))))):
return "invalid `n` parameter"
return None

def possibly_useful(self,
Expand Down Expand Up @@ -142,14 +137,10 @@ def typecheck(self, e : Exp, typecheck, report_err):
e.type = TMinHeap(e.e.type.elem_type, e.key_function.body.type)
elif isinstance(e, EHeapPeek) or isinstance(e, EHeapPeek2):
typecheck(e.e)
typecheck(e.heap_length)
ok = True
if not (isinstance(e.e.type, TMinHeap) or isinstance(e.e.type, TMaxHeap)):
report_err(e, "cannot peek a non-heap")
ok = False
if e.heap_length.type != INT:
report_err(e, "length param is not an int")
ok = False
if ok:
e.type = e.e.type.elem_type
elif isinstance(e, EHeapElems):
Expand Down Expand Up @@ -269,7 +260,7 @@ def mutate_in_place(self, lval, e, op, assumptions, invariants, make_subgoal):
SForEach(v, modified, SCall(lval, "update", (v, make_subgoal(new_v_key, a=[EIn(v, mod_spec)]))))])

def rep_type(self, t : Type) -> Type:
return TArray(t.elem_type)
return TTuple((INT, TArray(t.elem_type)))

def codegen(self, e : Exp, concretization_functions : { str : Exp }, out : EVar) -> Stm:
"""Return statements that write the result of `e` to `out`.
Expand All @@ -279,10 +270,10 @@ def codegen(self, e : Exp, concretization_functions : { str : Exp }, out : EVar)
"""
if isinstance(e, EMakeMinHeap) or isinstance(e, EMakeMaxHeap):
out_raw = EVar(out.id).with_type(self.rep_type(e.type))
l = fresh_var(INT, "alloc_len")
elem_type = e.type.elem_type
return seq([
SDecl(l, ELen(e.e)),
SArrayAlloc(out_raw, l),
SAssign(out_raw, ETuple([ELen(e.e), EArrayList().with_type(TArray(elem_type))]).with_type(self.rep_type(e.type))),
SArrayAlloc(ETupleGet(out_raw, 1).with_type(TArray(elem_type)), ELen(e.e)),
SCall(out, "add_all", (ZERO, e.e))])
elif isinstance(e, EHeapElems):
elem_type = e.type.elem_type
Expand All @@ -292,20 +283,21 @@ def codegen(self, e : Exp, concretization_functions : { str : Exp }, out : EVar)
i = fresh_var(INT, "i") # the array index
return seq([
SDecl(i, ZERO),
SWhile(ELt(i, EArrayLen(e.e).with_type(INT)), seq([
SCall(out, "add", (EArrayGet(e.e, i).with_type(elem_type),)),
SWhile(ELt(i, ETupleGet(e.e, 0).with_type(INT)), seq([
SCall(out, "add", (EArrayGet(ETupleGet(e.e, 1), i).with_type(elem_type),)),
SAssign(i, EBinOp(i, "+", ONE).with_type(INT))]))])
elif isinstance(e, EHeapPeek):
raise NotImplementedError()
elif isinstance(e, EHeapPeek2):
from cozy.evaluation import construct_value
best = EArgMin if isinstance(e.e.type, TMinHeap) else EArgMax
f = heap_func(e.e, concretization_functions)
return SSwitch(e.heap_length, (
return SSwitch(ETupleGet(e.e, 0), (
(ZERO, SAssign(out, construct_value(e.type))),
(ONE, SAssign(out, construct_value(e.type))),
(TWO, SAssign(out, EArrayGet(e.e, ONE).with_type(e.type)))),
SAssign(out, best(EBinOp(ESingleton(EArrayGet(e.e, ONE).with_type(e.type)).with_type(TBag(out.type)), "+", ESingleton(EArrayGet(e.e, TWO).with_type(e.type)).with_type(TBag(out.type))).with_type(TBag(out.type)), f).with_type(out.type)))
(TWO, SAssign(out, EArrayGet(ETupleGet(e.e, 1), ONE).with_type(e.type)))),
SAssign(out, best(EBinOp(ESingleton(EArrayGet(ETupleGet(e.e, 1), ONE).with_type(e.type)).with_type(TBag(out.type)), "+",
ESingleton(EArrayGet(ETupleGet(e.e, 1), TWO).with_type(e.type)).with_type(TBag(out.type))).with_type(TBag(out.type)), f).with_type(out.type)))
else:
raise NotImplementedError(e)

Expand All @@ -321,21 +313,24 @@ def implement_stmt(self, s : Stm, concretization_functions : { str : Exp }) -> S
if isinstance(s, SCall):
elem_type = s.target.type.elem_type
target_raw = EVar(s.target.id).with_type(self.rep_type(s.target.type))
target_len = ETupleGet(target_raw, 0).with_type(INT)
target_array = ETupleGet(target_raw, 1).with_type(TArray(elem_type))
if s.func == "add_all":
size = fresh_var(INT, "heap_size")
i = fresh_var(INT, "i")
x = fresh_var(elem_type, "x")
return seq([
SDecl(size, s.args[0]),
SEnsureCapacity(target_raw, EBinOp(size, "+", ELen(s.args[1])).with_type(INT)),
SEnsureCapacity(target_array, EBinOp(size, "+", ELen(s.args[1])).with_type(INT)),
SForEach(x, s.args[1], seq([
SAssign(EArrayGet(target_raw, size).with_type(elem_type), x),
SAssign(target_len, EBinOp(target_len, "+", ONE).with_type(INT)),
SAssign(EArrayGet(target_array, size).with_type(elem_type), x),
SDecl(i, size),
SWhile(EAll([
EBinOp(i, ">", ZERO).with_type(BOOL),
ENot(EBinOp(f.apply_to(EArrayGet(target_raw, _parent(i)).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_raw, i).with_type(elem_type))).with_type(BOOL))]),
ENot(EBinOp(f.apply_to(EArrayGet(target_array, _parent(i)).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_array, i).with_type(elem_type))).with_type(BOOL))]),
seq([
SSwap(EArrayGet(target_raw, _parent(i)).with_type(elem_type), EArrayGet(target_raw, i).with_type(elem_type)),
SSwap(EArrayGet(target_array, _parent(i)).with_type(elem_type), EArrayGet(target_array, i).with_type(elem_type)),
SAssign(i, _parent(i))])),
SAssign(size, EBinOp(size, "+", ONE).with_type(INT))]))])
elif s.func == "remove_all":
Expand All @@ -348,19 +343,20 @@ def implement_stmt(self, s : Stm, concretization_functions : { str : Exp }) -> S
return seq([
SDecl(size, s.args[0]),
SForEach(x, s.args[1], seq([
SAssign(target_len, EBinOp(target_len, "-", ONE).with_type(INT)),
# find the element to remove
SDecl(i, EArrayIndexOf(target_raw, x).with_type(INT)),
SDecl(i, EArrayIndexOf(target_array, x).with_type(INT)),
# swap with last element in heap
SSwap(EArrayGet(target_raw, i).with_type(elem_type), EArrayGet(target_raw, size_minus_one).with_type(elem_type)),
SSwap(EArrayGet(target_array, i).with_type(elem_type), EArrayGet(target_array, size_minus_one).with_type(elem_type)),
# bubble down
SEscapableBlock(label, SWhile(_has_left_child(i, size_minus_one), seq([
SDecl(child_index, _left_child(i)),
SIf(EAll([_has_right_child(i, size_minus_one), ENot(EBinOp(f.apply_to(EArrayGet(target_raw, _left_child(i)).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_raw, _right_child(i)).with_type(elem_type))))]),
SIf(EAll([_has_right_child(i, size_minus_one), ENot(EBinOp(f.apply_to(EArrayGet(target_array, _left_child(i)).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_array, _right_child(i)).with_type(elem_type))))]),
SAssign(child_index, _right_child(i)),
SNoOp()),
SIf(ENot(EBinOp(f.apply_to(EArrayGet(target_raw, i).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_raw, child_index).with_type(elem_type)))),
SIf(ENot(EBinOp(f.apply_to(EArrayGet(target_array, i).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_array, child_index).with_type(elem_type)))),
seq([
SSwap(EArrayGet(target_raw, i).with_type(elem_type), EArrayGet(target_raw, child_index).with_type(elem_type)),
SSwap(EArrayGet(target_array, i).with_type(elem_type), EArrayGet(target_array, child_index).with_type(elem_type)),
SAssign(i, child_index)]),
SEscapeBlock(label))]))),
# dec. size
Expand Down
2 changes: 1 addition & 1 deletion cozy/synthesis/acceleration.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def optimized_best(xs, keyfunc, op, args):
h = make_heap(bag.e, keyfunc).with_type(heap_type(elem_type, key_type))
for prev_min in optimized_best(bag.e, keyfunc, op, args=args):
prev_min = EStateVar(prev_min).with_type(elem_type)
heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type), EStateVar(ELen(bag.e)).with_type(INT)).with_type(elem_type)
heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type)).with_type(elem_type)
conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
if isinstance(x, EUnaryOp) and x.op == UOp.The:
conds = [optimized_exists(x.e)] + conds
Expand Down
16 changes: 14 additions & 2 deletions tests/codegen.py

Large diffs are not rendered by default.

Loading

0 comments on commit a95628b

Please sign in to comment.