Skip to content

Commit

Permalink
fix EHeapElems bug by combining length member with heap array
Browse files Browse the repository at this point in the history
  • Loading branch information
izgzhen committed Nov 1, 2018
1 parent 9d6582f commit 3b71f13
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 51 deletions.
49 changes: 38 additions & 11 deletions cozy/codegen/cxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from cozy import common, evaluation
from cozy.common import fresh_name, declare_case, extend
from cozy.target_syntax import *
from cozy.syntax_tools import all_types, fresh_var, subst, free_vars, all_exps, break_seq, is_lvalue
from cozy.syntax_tools import all_types, fresh_var, subst, free_vars, all_exps, break_seq, is_lvalue, shallow_copy
from cozy.typecheck import is_collection, is_scalar
from cozy.structures import extension_handler
from cozy.structures.arrays import TArray, EArrayGet

from .misc import *

Expand Down Expand Up @@ -348,7 +349,9 @@ def visit_EListGet(self, e):
evaluation.construct_value(e.type)).with_type(e.type))).with_type(e.type))

def visit_EArrayIndexOf(self, e):
assert isinstance(e.a, EVar) # TODO: make this fast when this is false
if isinstance(e.a, EVar): pass
elif isinstance(e.a, ETupleGet) and isinstance(e.a.e, EVar): pass
else: raise NotImplementedError("finding index of non-var array") # TODO: make this fast when this is false
it = self.fv(TNative("{}::const_iterator".format(self.visit(e.a.type, "").strip())), "cursor")
res = self.fv(INT, "index")
self.visit(seq([
Expand All @@ -374,6 +377,9 @@ def visit_EMakeMap2(self, e):
self.declare(m, e)
return m.id

def visit_EArrayList(self, e):
return "(std::vector< {} >())".format(self.visit(e.type.elem_type, ""))

def visit_EHasKey(self, e):
map = self.visit(e.map)
key = self.visit(e.key)
Expand Down Expand Up @@ -818,7 +824,7 @@ def visit_SAssign(self, s):
def visit_SDecl(self, s):
assert isinstance(s.var, EVar)
t = s.val.type
return self.declare(s.var.with_type(t), s.val)
return self.declare(shallow_copy(s.var).with_type(t), s.val)

def visit_SSeq(self, s):
for ss in break_seq(s):
Expand Down Expand Up @@ -980,6 +986,12 @@ def setup_types(self, spec, state_exps, sharing):
if t not in self.types and type(t) in [THandle, TRecord, TTuple, TEnum]:
name = names.get(t, self.fn("Type"))
self.types[t] = name
# add representation type for extension data structures
h = extension_handler(type(t))
if t not in self.types and h is not None:
name = names.get(t, self.fn("Type"))
t = h.rep_type(t)
self.types[t] = name

def visit_args(self, args):
for (i, (v, t)) in enumerate(args):
Expand All @@ -998,17 +1010,32 @@ def _hasher(self, t : Type) -> str:
except KeyError:
return "std::hash<{}>".format(self.visit(t, ""))

def compute_hash_1(self, e : Exp) -> Exp:
return EEscape("{hasher}()({{e}})".format(hasher=self._hasher(e.type)), ("e",), (e,)).with_type(TNative("std::size_t"))
def compute_hash_scalar(self, e: Exp) -> Exp:
return EEscape("{hasher}()({{e}})".format(hasher=self._hasher(e.type)), ("e",), (e,)).with_type(INT)

def compute_hash_1(self, hc: Exp, e : Exp) -> Stm:
if is_scalar(e.type):
return SAssign(hc, self.compute_hash_scalar(e))
elif isinstance(e.type, TArray):
x = fresh_var(e.type.elem_type, "x")
s = SSeq(SAssign(hc, ZERO.with_type(hc.type)),
SForEach(x, e,
SAssign(hc, EEscape("({hc} * 31) ^ ({h})", ("hc", "h"),
(hc, self.compute_hash_scalar(x))).with_type(INT))))
return s
else:
raise NotImplementedError("can't compute hash for type {}".format(e.type))

def compute_hash(self, fields : [Exp]) -> Stm:
hc = self.fv(TNative("std::size_t"), "hash_code")
s = SDecl(hc, ENum(0).with_type(hc.type))
hc = self.fv(INT, "hash_code")
h = self.fv(INT, "hash_code")
s = SSeq(SDecl(hc, ENum(0).with_type(hc.type)),
SDecl(h, ENum(0).with_type(h.type)))
for f in fields:
# return SAssign(out, )
s = SSeq(s, SAssign(hc,
EEscape("({hc} * 31) ^ ({h})", ("hc", "h"),
(hc, self.compute_hash_1(f))).with_type(TNative("std::size_t"))))
s = seq([s,
self.compute_hash_1(h, f),
SAssign(hc,
EEscape("({hc} * 31) ^ ({h})", ("hc", "h"), (hc, h)).with_type(INT))])
s = SSeq(s, SEscape("{indent}return {e};\n", ("e",), (hc,)))
return s

Expand Down
3 changes: 3 additions & 0 deletions cozy/codegen/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def body(x):
self.end_statement()
self.end_statement()

def visit_EArrayList(self, e):
return "(new {}[0])".format(self.visit(e.type.elem_type, ""))

def initialize_native_list(self, out):
init = "new {};\n".format(self.visit(out.type, name="()"))
return SEscape("{indent}{e} = " + init, ["e"], [out])
Expand Down
1 change: 1 addition & 0 deletions cozy/structures/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from cozy.syntax import Type, Exp, Stm

TArray = declare_case(Type, "TArray", ["elem_type"])
EArrayList = declare_case(Exp, "EArrayList", [])
EArrayCapacity = declare_case(Exp, "EArrayCapacity", ["e"])
EArrayLen = declare_case(Exp, "EArrayLen", ["e"])
EArrayGet = declare_case(Exp, "EArrayGet", ["a", "i"])
Expand Down
60 changes: 28 additions & 32 deletions cozy/structures/heaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from cozy.syntax_tools import fresh_var, pprint, mk_lambda, alpha_equivalent
from cozy.pools import Pool, RUNTIME_POOL, STATE_POOL

from .arrays import TArray, EArrayGet, EArrayIndexOf, SArrayAlloc, SEnsureCapacity, EArrayLen
from .arrays import TArray, EArrayGet, EArrayIndexOf, SArrayAlloc, SEnsureCapacity, EArrayLen, EArrayList

TMinHeap = declare_case(Type, "TMinHeap", ["elem_type", "key_type"])
TMaxHeap = declare_case(Type, "TMaxHeap", ["elem_type", "key_type"])
Expand All @@ -14,8 +14,8 @@
EMakeMaxHeap = declare_case(Exp, "EMakeMaxHeap", ["e", "key_function"])

EHeapElems = declare_case(Exp, "EHeapElems", ["e"]) # all elements
EHeapPeek = declare_case(Exp, "EHeapPeek", ["e", "heap_length"]) # look at min
EHeapPeek2 = declare_case(Exp, "EHeapPeek2", ["e", "heap_length"]) # look at 2nd min
EHeapPeek = declare_case(Exp, "EHeapPeek", ["e"]) # look at min
EHeapPeek2 = declare_case(Exp, "EHeapPeek2", ["e"]) # look at 2nd min

def to_heap(e : Exp) -> Exp:
"""Construct a heap that would be useful for evaluating `e`.
Expand Down Expand Up @@ -90,10 +90,9 @@ def enumerate(self, context, size, pool, enumerate_subexps, enumerate_lambdas):
t = e.type
if isinstance(t, TMinHeap) or isinstance(t, TMaxHeap):
elem_type = t.elem_type
n_elems = EStateVar(ELen(EHeapElems(e).with_type(TBag(elem_type)))).with_type(INT)
# yielding EHeapElems would be redundant
yield EHeapPeek (EStateVar(e).with_type(e.type), n_elems).with_type(elem_type)
yield EHeapPeek2(EStateVar(e).with_type(e.type), n_elems).with_type(elem_type)
yield EHeapPeek (EStateVar(e).with_type(e.type)).with_type(elem_type)
yield EHeapPeek2(EStateVar(e).with_type(e.type)).with_type(elem_type)

def default_value(self, t : Type, default_value) -> Exp:
"""Construct a default value of the given type.
Expand All @@ -108,10 +107,6 @@ def default_value(self, t : Type, default_value) -> Exp:

def check_wf(self, e : Exp, is_valid, state_vars : {EVar}, args : {EVar}, pool = RUNTIME_POOL, assumptions : Exp = ETRUE):
"""Return None or a string indicating a well-formedness error."""
if (isinstance(e, EHeapPeek) or isinstance(e, EHeapPeek2)):
heap = e.e
if not is_valid(EEq(e.heap_length, ELen(EHeapElems(heap).with_type(TBag(heap.type.elem_type))))):
return "invalid `n` parameter"
return None

def possibly_useful(self,
Expand Down Expand Up @@ -142,14 +137,10 @@ def typecheck(self, e : Exp, typecheck, report_err):
e.type = TMinHeap(e.e.type.elem_type, e.key_function.body.type)
elif isinstance(e, EHeapPeek) or isinstance(e, EHeapPeek2):
typecheck(e.e)
typecheck(e.heap_length)
ok = True
if not (isinstance(e.e.type, TMinHeap) or isinstance(e.e.type, TMaxHeap)):
report_err(e, "cannot peek a non-heap")
ok = False
if e.heap_length.type != INT:
report_err(e, "length param is not an int")
ok = False
if ok:
e.type = e.e.type.elem_type
elif isinstance(e, EHeapElems):
Expand Down Expand Up @@ -269,7 +260,7 @@ def mutate_in_place(self, lval, e, op, assumptions, invariants, make_subgoal):
SForEach(v, modified, SCall(lval, "update", (v, make_subgoal(new_v_key, a=[EIn(v, mod_spec)]))))])

def rep_type(self, t : Type) -> Type:
return TArray(t.elem_type)
return TTuple((INT, TArray(t.elem_type)))

def codegen(self, e : Exp, concretization_functions : { str : Exp }, out : EVar) -> Stm:
"""Return statements that write the result of `e` to `out`.
Expand All @@ -279,10 +270,10 @@ def codegen(self, e : Exp, concretization_functions : { str : Exp }, out : EVar)
"""
if isinstance(e, EMakeMinHeap) or isinstance(e, EMakeMaxHeap):
out_raw = EVar(out.id).with_type(self.rep_type(e.type))
l = fresh_var(INT, "alloc_len")
elem_type = e.type.elem_type
return seq([
SDecl(l, ELen(e.e)),
SArrayAlloc(out_raw, l),
SAssign(out_raw, ETuple([ELen(e.e), EArrayList().with_type(TArray(elem_type))]).with_type(self.rep_type(e.type))),
SArrayAlloc(ETupleGet(out_raw, 1).with_type(TArray(elem_type)), ELen(e.e)),
SCall(out, "add_all", (ZERO, e.e))])
elif isinstance(e, EHeapElems):
elem_type = e.type.elem_type
Expand All @@ -292,20 +283,21 @@ def codegen(self, e : Exp, concretization_functions : { str : Exp }, out : EVar)
i = fresh_var(INT, "i") # the array index
return seq([
SDecl(i, ZERO),
SWhile(ELt(i, EArrayLen(e.e).with_type(INT)), seq([
SCall(out, "add", (EArrayGet(e.e, i).with_type(elem_type),)),
SWhile(ELt(i, ETupleGet(e.e, 0).with_type(INT)), seq([
SCall(out, "add", (EArrayGet(ETupleGet(e.e, 1), i).with_type(elem_type),)),
SAssign(i, EBinOp(i, "+", ONE).with_type(INT))]))])
elif isinstance(e, EHeapPeek):
raise NotImplementedError()
elif isinstance(e, EHeapPeek2):
from cozy.evaluation import construct_value
best = EArgMin if isinstance(e.e.type, TMinHeap) else EArgMax
f = heap_func(e.e, concretization_functions)
return SSwitch(e.heap_length, (
return SSwitch(ETupleGet(e.e, 0), (
(ZERO, SAssign(out, construct_value(e.type))),
(ONE, SAssign(out, construct_value(e.type))),
(TWO, SAssign(out, EArrayGet(e.e, ONE).with_type(e.type)))),
SAssign(out, best(EBinOp(ESingleton(EArrayGet(e.e, ONE).with_type(e.type)).with_type(TBag(out.type)), "+", ESingleton(EArrayGet(e.e, TWO).with_type(e.type)).with_type(TBag(out.type))).with_type(TBag(out.type)), f).with_type(out.type)))
(TWO, SAssign(out, EArrayGet(ETupleGet(e.e, 1), ONE).with_type(e.type)))),
SAssign(out, best(EBinOp(ESingleton(EArrayGet(ETupleGet(e.e, 1), ONE).with_type(e.type)).with_type(TBag(out.type)), "+",
ESingleton(EArrayGet(ETupleGet(e.e, 1), TWO).with_type(e.type)).with_type(TBag(out.type))).with_type(TBag(out.type)), f).with_type(out.type)))
else:
raise NotImplementedError(e)

Expand All @@ -321,21 +313,24 @@ def implement_stmt(self, s : Stm, concretization_functions : { str : Exp }) -> S
if isinstance(s, SCall):
elem_type = s.target.type.elem_type
target_raw = EVar(s.target.id).with_type(self.rep_type(s.target.type))
target_len = ETupleGet(target_raw, 0).with_type(INT)
target_array = ETupleGet(target_raw, 1).with_type(TArray(elem_type))
if s.func == "add_all":
size = fresh_var(INT, "heap_size")
i = fresh_var(INT, "i")
x = fresh_var(elem_type, "x")
return seq([
SDecl(size, s.args[0]),
SEnsureCapacity(target_raw, EBinOp(size, "+", ELen(s.args[1])).with_type(INT)),
SEnsureCapacity(target_array, EBinOp(size, "+", ELen(s.args[1])).with_type(INT)),
SForEach(x, s.args[1], seq([
SAssign(EArrayGet(target_raw, size).with_type(elem_type), x),
SAssign(target_len, EBinOp(target_len, "+", ONE).with_type(INT)),
SAssign(EArrayGet(target_array, size).with_type(elem_type), x),
SDecl(i, size),
SWhile(EAll([
EBinOp(i, ">", ZERO).with_type(BOOL),
ENot(EBinOp(f.apply_to(EArrayGet(target_raw, _parent(i)).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_raw, i).with_type(elem_type))).with_type(BOOL))]),
ENot(EBinOp(f.apply_to(EArrayGet(target_array, _parent(i)).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_array, i).with_type(elem_type))).with_type(BOOL))]),
seq([
SSwap(EArrayGet(target_raw, _parent(i)).with_type(elem_type), EArrayGet(target_raw, i).with_type(elem_type)),
SSwap(EArrayGet(target_array, _parent(i)).with_type(elem_type), EArrayGet(target_array, i).with_type(elem_type)),
SAssign(i, _parent(i))])),
SAssign(size, EBinOp(size, "+", ONE).with_type(INT))]))])
elif s.func == "remove_all":
Expand All @@ -348,19 +343,20 @@ def implement_stmt(self, s : Stm, concretization_functions : { str : Exp }) -> S
return seq([
SDecl(size, s.args[0]),
SForEach(x, s.args[1], seq([
SAssign(target_len, EBinOp(target_len, "-", ONE).with_type(INT)),
# find the element to remove
SDecl(i, EArrayIndexOf(target_raw, x).with_type(INT)),
SDecl(i, EArrayIndexOf(target_array, x).with_type(INT)),
# swap with last element in heap
SSwap(EArrayGet(target_raw, i).with_type(elem_type), EArrayGet(target_raw, size_minus_one).with_type(elem_type)),
SSwap(EArrayGet(target_array, i).with_type(elem_type), EArrayGet(target_array, size_minus_one).with_type(elem_type)),
# bubble down
SEscapableBlock(label, SWhile(_has_left_child(i, size_minus_one), seq([
SDecl(child_index, _left_child(i)),
SIf(EAll([_has_right_child(i, size_minus_one), ENot(EBinOp(f.apply_to(EArrayGet(target_raw, _left_child(i)).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_raw, _right_child(i)).with_type(elem_type))))]),
SIf(EAll([_has_right_child(i, size_minus_one), ENot(EBinOp(f.apply_to(EArrayGet(target_array, _left_child(i)).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_array, _right_child(i)).with_type(elem_type))))]),
SAssign(child_index, _right_child(i)),
SNoOp()),
SIf(ENot(EBinOp(f.apply_to(EArrayGet(target_raw, i).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_raw, child_index).with_type(elem_type)))),
SIf(ENot(EBinOp(f.apply_to(EArrayGet(target_array, i).with_type(elem_type)), comparison_op, f.apply_to(EArrayGet(target_array, child_index).with_type(elem_type)))),
seq([
SSwap(EArrayGet(target_raw, i).with_type(elem_type), EArrayGet(target_raw, child_index).with_type(elem_type)),
SSwap(EArrayGet(target_array, i).with_type(elem_type), EArrayGet(target_array, child_index).with_type(elem_type)),
SAssign(i, child_index)]),
SEscapeBlock(label))]))),
# dec. size
Expand Down
2 changes: 1 addition & 1 deletion cozy/synthesis/acceleration.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def optimized_best(xs, keyfunc, op, args):
h = make_heap(bag.e, keyfunc).with_type(heap_type(elem_type, key_type))
for prev_min in optimized_best(bag.e, keyfunc, op, args=args):
prev_min = EStateVar(prev_min).with_type(elem_type)
heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type), EStateVar(ELen(bag.e)).with_type(INT)).with_type(elem_type)
heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type)).with_type(elem_type)
conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
if isinstance(x, EUnaryOp) and x.op == UOp.The:
conds = [optimized_exists(x.e)] + conds
Expand Down
3 changes: 2 additions & 1 deletion cozy/typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from cozy.syntax import BOOL, INT, LONG, FLOAT, STRING
from cozy.structures import extension_handler
from cozy.structures.arrays import TArray

def typecheck(
ast,
Expand Down Expand Up @@ -100,7 +101,7 @@ def equality_implies_deep_equality(t : syntax.Type):
return equality_implies_deep_equality(t.elem_type)
return False

COLLECTION_TYPES = (syntax.TBag, syntax.TSet, syntax.TList)
COLLECTION_TYPES = (syntax.TBag, syntax.TSet, syntax.TList, TArray)
def is_collection(t):
return any(isinstance(t, ct) for ct in COLLECTION_TYPES)

Expand Down
16 changes: 14 additions & 2 deletions tests/codegen.py

Large diffs are not rendered by default.

Loading

0 comments on commit 3b71f13

Please sign in to comment.