Skip to content

Commit

Permalink
Merge pull request #40 from CozySynthesizer/let-expressions
Browse files Browse the repository at this point in the history
Synthesize let-expressions
  • Loading branch information
Calvin-L authored Jul 3, 2018
2 parents 7eebc85 + 53b5a8a commit c46b02a
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 1 deletion.
8 changes: 7 additions & 1 deletion cozy/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import itertools

from cozy.common import OrderedSet, unique, Visitor
from cozy.syntax import TFunc, Exp, EVar, EAll
from cozy.syntax import TFunc, TBag, Exp, EVar, EAll, ESingleton
from cozy.target_syntax import EDeepIn
from cozy.evaluation import eval
from cozy.syntax_tools import pprint, alpha_equivalent, free_vars, subst, BottomUpRewriter
Expand Down Expand Up @@ -202,6 +202,10 @@ def visit_EMakeMaxHeap(self, e):
yield (e, self.ctx, self.pool)
yield from self.visit(e.e)
yield from self.visit(e.f, e.e)
def visit_ELet(self, e):
yield (e, self.ctx, self.pool)
yield from self.visit(e.e)
yield from self.visit(e.f, ESingleton(e.e).with_type(TBag(e.e.type)))
def visit_Exp(self, e):
yield (e, self.ctx, self.pool)
for child in e.children():
Expand Down Expand Up @@ -269,6 +273,8 @@ def visit_EMakeMinHeap(self, e):
return self.join(e, (self.visit(e.e), self.visit(e.f, e.e)))
def visit_EMakeMaxHeap(self, e):
return self.join(e, (self.visit(e.e), self.visit(e.f, e.e)))
def visit_ELet(self, e):
return self.join(e, (self.visit(e.e), self.visit(e.f, ESingleton(e.e).with_type(TBag(e.e.type)))))
def visit(self, e, *args):
if isinstance(e, Exp) and _sametype(e, self.needle) and self.pool == self.needle_pool and alpha_equivalent(self.needle, e) and self.needle_context.alpha_equivalent(self.ctx):
return self.ctx.adapt(self.replacement, self.needle_context)
Expand Down
4 changes: 4 additions & 0 deletions cozy/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ def rt(e, account_for_constant_factors=True):
rt(e.then_branch),
rt(e.else_branch)).with_type(INT))
continue
if isinstance(e, ELet):
stk.append(e.e)
terms.append(ELet(e.e, ELambda(e.f.arg, rt(e.f.body))).with_type(INT))
continue

constant += 1
if isinstance(e, EStateVar):
Expand Down
10 changes: 10 additions & 0 deletions cozy/synthesis/enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,16 @@ def build_lambdas(bag, pool, body_size):
for lam_body in self.enumerate(inner_context, body_size, pool):
yield ELambda(v, lam_body)

# Let-expressions
for (sz1, sz2) in pick_to_sum(2, size - 1):
for x in self.enumerate(context, sz1, pool):
bag = ESingleton(x).with_type(TBag(x.type))
for lam in build_lambdas(bag, pool, sz2):
e = ELet(x, lam).with_type(lam.body.type)
# if x == EBinOp(EVar("x"), "+", EVar("x")):
# e._tag = True
yield e

# Iteration
for (sz1, sz2) in pick_to_sum(2, size - 1):
for bag in collections(self.enumerate(context, sz1, pool)):
Expand Down
15 changes: 15 additions & 0 deletions tests/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,18 @@ def test_pool_affects_alpha_equivalence(self):

assert c1 != c2
assert not c1.alpha_equivalent(c2)

def test_let(self):
e1 = ELet(ZERO, ELambda(x, x))
root_ctx = RootCtx(args=(), state_vars=())
assert retypecheck(e1)
n = 0
for ee, ctx, pool in shred(e1, root_ctx, RUNTIME_POOL):
if ee == x:
e2 = replace(
e1, root_ctx, RUNTIME_POOL,
x, ctx, pool,
ZERO)
assert e2 == ELet(ZERO, ELambda(x, ZERO))
n += 1
assert n == 1
Loading

0 comments on commit c46b02a

Please sign in to comment.