Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix EMapKeys codegen failure #112

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cozy/codegen/cxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Type, INT, BOOL, TNative, TSet, TList, TBag, THandle, TEnum, TTuple, TRecord, TFloat,
Exp, EVar, ENum, EFALSE, ETRUE, ZERO, ENull, EEq, ELt, ENot, ECond, EAll,
EEnumEntry, ETuple, ETupleGet, EGetField,
Stm, SNoOp, SIf, SDecl, SSeq, seq, SForEach, SAssign)
Stm, SNoOp, SIf, SDecl, SSeq, seq, SForEach, SAssign, SCall)
from cozy.target_syntax import TArray, TRef, EEnumToInt, EMapKeys, SReturn
from cozy.syntax_tools import pprint, all_types, fresh_var, subst, free_vars, all_exps, break_seq, shallow_copy
from cozy.typecheck import is_collection, is_scalar
Expand Down Expand Up @@ -272,6 +272,14 @@ def visit_EMakeMap2(self, e):
self.declare(m, e)
return m.id

def visit_EMapKeys(self, e):
key = self.fv(e.type.elem_type)
keys = self.fv(e.type)
self.declare(keys)
add_to_keys = SCall(keys, "add", [key])
self.visit(SForEach(key, e, add_to_keys))
return keys.id

def visit_EHasKey(self, e):
map = self.visit(e.map)
key = self.visit(e.key)
Expand Down
2 changes: 1 addition & 1 deletion cozy/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def __init__(self, parent : Context, v : EVar, bag : Exp, bag_pool : Pool):

`v` must not already be described by the parent context.
"""
assert v.type == bag.type.elem_type
assert v.type == bag.type.elem_type, "%s, %s" % (v.type, bag.type)
assert parent.legal_for(free_vars(bag)), "cannot create context for {} in {}, {}".format(v.id, pprint(bag), parent)
assert not any(v == vv for vv, p in parent.vars()), "binder {} already free in {}".format(v.id, parent)
self._parent = parent
Expand Down
20 changes: 18 additions & 2 deletions cozy/synthesis/acceleration.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,13 @@ def optimized_map(xs, f, args):

sum_of = lambda xs: EUnaryOp(UOp.Sum, xs).with_type(xs.type.elem_type)

def optimized_sum(xs, args):
def optimized_sum(xs, args, inplace_opt=True):
"""
:param xs:
:param args:
:param inplace_opt: whether to apply inplace optimization, which might cause infinite recursive
:return:
"""
elem_type = xs.type.elem_type
if isinstance(xs, EStateVar):
yield EStateVar(sum_of(strip_EStateVar(xs))).with_type(elem_type)
Expand All @@ -653,12 +659,22 @@ def optimized_sum(xs, args):
yield xs.e
if isinstance(xs, EFlatMap):
f = xs.transform_function
# sum flatMap(e1 + e2, f) == sum flatMap(e1, f) + sum flatMap(e2, f)
if isinstance(f.body, EBinOp) and f.body.op == "+":
for e1 in optimized_flatmap(xs.e, ELambda(f.arg, f.body.e1), args):
for e2 in optimized_flatmap(xs.e, ELambda(f.arg, f.body.e2), args):
for e in optimized_sum(EBinOp(e1, "+", e2).with_type(e1.type), args):
yield e

# sum flatMap(R, \r -> map(S, \s -> g(r, s))) == sum flatMap(S, \s -> map(R, \r -> g(r, s)))
if isinstance(f.body, EMap) and inplace_opt:
R = xs.e
r = f.arg
S = f.body.e
s = f.body.transform_function.arg
g = f.body.transform_function.body
for e in optimized_flatmap(S, ELambda(s, EMap(R, ELambda(r, g)).with_type(type(R.type)(g.type))), args):
for e2 in optimized_sum(e, args, inplace_opt=False):
yield e2
yield sum_of(xs)

def optimized_flatmap(xs, f, args):
Expand Down
2 changes: 1 addition & 1 deletion examples/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ select.h: select-flatmap.ds
cozy -t $(TIMEOUT) --allow-big-sets select-flatmap.ds --c++ select.h -p 8080 --verbose --save select.synthesized

listcomp: listcomp.cpp listcomp.h
g++ -std=c++11 -O3 -Werror '$<' -o '$@'
g++ -std=c++11 -O3 -Wno-parentheses-equality -Werror '$<' -o '$@'

run-listcomp: listcomp
time ./listcomp
3 changes: 3 additions & 0 deletions examples/listcomp-flatmap.ds
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ ListComp02:

op insert_s(s: S)
Ss.add(s);

op delete_r(a: Int)
Rs.remove_all([ r | r <- Rs, r.A == a ]);