Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental @alpha_rename decorator for Python functions #532

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 93 additions & 24 deletions funsor/syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ast
import functools
import inspect
import itertools

from . import ops

Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self, infix, prefix, const):
self.infix = {INFIX_TO_NODE[k]: v for k, v in infix.items()}
self.prefix = {PREFIX_TO_NODE[k]: v for k, v in prefix.items()}
self.const = const
super().__init__()

def visit_Constant(self, node):
node = self.generic_visit(node)
Expand Down Expand Up @@ -131,38 +133,105 @@ def product_rule(sum_op, prod_op, lhs, rhs, d):
transformer = OpTransformer(infix, prefix, const)

def decorator(fn):
source = inspect.getsource(fn)

# Strip indentation and all decorators.
indent = len(source) - len(source.lstrip())
lines = []
discard = True
for line in source.split("\n"):
line = line[indent:]
if discard:
if line.startswith("def "):
discard = False
else:
continue
lines.append(line)
source = "\n".join(lines)
assert source

# Transform the function.
a = ast.parse(source)
a = decompile_def(fn)
a_t = transformer.visit(a)
source_t = ast.unparse(a_t)
result = {}
exec(source_t, globals(), result)
fn_t = result[fn.__name__]
functools.update_wrapper(fn_t, fn)
fn_t = recompile_def(fn, a_t)
return fn_t

return decorator


def _find_names(count, avoid):
"""
Finds count-many distincy variable names, avoiding names in ``avoid``.

:param avoid: A collection of names to avoid.
:returns: A variable name something like "_bound_123"
:rtype: str
"""
assert isinstance(count, int) and count >= 0
result = []
for i in itertools.count():
if len(result) == count:
return result
name = f"_bound_{i}"
if name not in avoid:
result.append(name)


def alpha_rename(fn=None, locals_=None):
"""
Rename all position-only arguments in a function.
"""
if fn is None:
return functools.partial(alpha_rename, locals_=locals_)

# Create a canonical alpha renaming.
sig = inspect.signature(fn)
old_names = {
name for name, p in sig.parameters.items() if p.kind == p.POSITIONAL_ONLY
}
avoid = set(fn.__code__.co_varnames) - old_names
new_names = _find_names(len(old_names), avoid)
rename = dict(zip(old_names, new_names))

# Rename variables in-place.
a = decompile_def(fn)
for node in ast.walk(a):
if isinstance(node, ast.FunctionDef):
for arg in node.args.posonlyargs:
arg.arg = rename.get(arg.arg, arg.arg)
elif isinstance(node, ast.Name):
node.id = rename.get(node.id, node.id)
fn_t = recompile_def(fn, a, locals_)
return fn_t


def decompile_def(fn):
"""
Decompile a function definition to an ast, dropping all decorators.

:param callable fn:
:returns: an ast representation of ``fn``
:rtype: ast.Module
"""
source = inspect.getsource(fn)

# Strip indentation and all decorators.
indent = len(source) - len(source.lstrip())
lines = []
discard = True
for line in source.split("\n"):
line = line[indent:]
if discard:
if line.startswith("def "):
discard = False
else:
continue
lines.append(line)
source = "\n".join(lines)
assert source

return ast.parse(source)


def recompile_def(fn, a, locals_=None):
"""
Recompile the ast ``a`` to function like ``fn``.
"""
if locals_ is None:
locals_ = {}
source = ast.unparse(a)
exec(source, globals(), locals_)
fn_t = locals_[fn.__name__]
functools.update_wrapper(fn_t, fn, assigned=("__module__",), updated=())
return fn_t


__all__ = [
"INFIX_OPERATORS",
"PREFIX_OPERATORS",
"decompile_def",
"recompile_def",
"rewrite_ops",
]
41 changes: 40 additions & 1 deletion test/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from funsor import ops
from funsor.syntax import rewrite_ops
from funsor.syntax import alpha_rename, rewrite_ops

python_version = tuple(map(int, sys.version.split()[0].split(".")[:2]))

Expand All @@ -16,6 +16,7 @@ def assert_fn_eq(actual, expected):
assert actual.__name__ == expected.__name__
assert actual.__doc__ == expected.__doc__
assert actual.__module__ == expected.__module__
assert actual.__closure__ == expected.__closure__
assert actual.__code__.co_code == expected.__code__.co_code


Expand Down Expand Up @@ -106,3 +107,41 @@ def foo(add_op, mul_op, x, y, z):

args = (ops.mul, ops.mul, 1.23, 4.56, 7.89)
assert actual(*args) == expected(*args)


def test_alpha_rename_1():
@alpha_rename
def fn(a, b, /, c, d=1):
e = a + b
return c + d + e

actual = fn

def fn(_bound_0, _bound_1, /, c, d=1):
e = _bound_0 + _bound_1
return c + d + e

expected = fn
assert_fn_eq(actual, expected)

assert actual(1, 2, 3, 4) == expected(1, 2, 3, 4)


@pytest.mark.xfail(reason="failure to reproduce enclosing scope")
def test_alpha_rename_2():
@alpha_rename(locals_=locals())
def fn(a, b, /, c, d=1):
e = a + b
return c + d + e + f

actual = fn

def fn(_bound_0, _bound_1, /, c, d=1):
e = _bound_0 + _bound_1
return c + d + e + f

expected = fn
assert_fn_eq(actual, expected)

f = 0
assert actual(1, 2, 3, 4) == expected(1, 2, 3, 4)