diff --git a/funsor/syntax.py b/funsor/syntax.py index 0caa34cd..d7e9aa13 100644 --- a/funsor/syntax.py +++ b/funsor/syntax.py @@ -4,6 +4,7 @@ import ast import functools import inspect +import itertools from . import ops @@ -46,6 +47,7 @@ def __init__(self, infix, prefix, const): self.infix = {INFIX_TO_NODE[k]: v for k, v in infix.items()} self.prefix = {PREFIX_TO_NODE[k]: v for k, v in prefix.items()} self.const = const + super().__init__() def visit_Constant(self, node): node = self.generic_visit(node) @@ -131,38 +133,105 @@ def product_rule(sum_op, prod_op, lhs, rhs, d): transformer = OpTransformer(infix, prefix, const) def decorator(fn): - source = inspect.getsource(fn) - - # Strip indentation and all decorators. - indent = len(source) - len(source.lstrip()) - lines = [] - discard = True - for line in source.split("\n"): - line = line[indent:] - if discard: - if line.startswith("def "): - discard = False - else: - continue - lines.append(line) - source = "\n".join(lines) - assert source - - # Transform the function. - a = ast.parse(source) + a = decompile_def(fn) a_t = transformer.visit(a) - source_t = ast.unparse(a_t) - result = {} - exec(source_t, globals(), result) - fn_t = result[fn.__name__] - functools.update_wrapper(fn_t, fn) + fn_t = recompile_def(fn, a_t) return fn_t return decorator +def _find_names(count, avoid): + """ + Finds count-many distincy variable names, avoiding names in ``avoid``. + + :param avoid: A collection of names to avoid. + :returns: A variable name something like "_bound_123" + :rtype: str + """ + assert isinstance(count, int) and count >= 0 + result = [] + for i in itertools.count(): + if len(result) == count: + return result + name = f"_bound_{i}" + if name not in avoid: + result.append(name) + + +def alpha_rename(fn=None, locals_=None): + """ + Rename all position-only arguments in a function. + """ + if fn is None: + return functools.partial(alpha_rename, locals_=locals_) + + # Create a canonical alpha renaming. + sig = inspect.signature(fn) + old_names = { + name for name, p in sig.parameters.items() if p.kind == p.POSITIONAL_ONLY + } + avoid = set(fn.__code__.co_varnames) - old_names + new_names = _find_names(len(old_names), avoid) + rename = dict(zip(old_names, new_names)) + + # Rename variables in-place. + a = decompile_def(fn) + for node in ast.walk(a): + if isinstance(node, ast.FunctionDef): + for arg in node.args.posonlyargs: + arg.arg = rename.get(arg.arg, arg.arg) + elif isinstance(node, ast.Name): + node.id = rename.get(node.id, node.id) + fn_t = recompile_def(fn, a, locals_) + return fn_t + + +def decompile_def(fn): + """ + Decompile a function definition to an ast, dropping all decorators. + + :param callable fn: + :returns: an ast representation of ``fn`` + :rtype: ast.Module + """ + source = inspect.getsource(fn) + + # Strip indentation and all decorators. + indent = len(source) - len(source.lstrip()) + lines = [] + discard = True + for line in source.split("\n"): + line = line[indent:] + if discard: + if line.startswith("def "): + discard = False + else: + continue + lines.append(line) + source = "\n".join(lines) + assert source + + return ast.parse(source) + + +def recompile_def(fn, a, locals_=None): + """ + Recompile the ast ``a`` to function like ``fn``. + """ + if locals_ is None: + locals_ = {} + source = ast.unparse(a) + exec(source, globals(), locals_) + fn_t = locals_[fn.__name__] + functools.update_wrapper(fn_t, fn, assigned=("__module__",), updated=()) + return fn_t + + __all__ = [ "INFIX_OPERATORS", "PREFIX_OPERATORS", + "decompile_def", + "recompile_def", "rewrite_ops", ] diff --git a/test/test_syntax.py b/test/test_syntax.py index 241bff85..e0e9e599 100644 --- a/test/test_syntax.py +++ b/test/test_syntax.py @@ -7,7 +7,7 @@ import pytest from funsor import ops -from funsor.syntax import rewrite_ops +from funsor.syntax import alpha_rename, rewrite_ops python_version = tuple(map(int, sys.version.split()[0].split(".")[:2])) @@ -16,6 +16,7 @@ def assert_fn_eq(actual, expected): assert actual.__name__ == expected.__name__ assert actual.__doc__ == expected.__doc__ assert actual.__module__ == expected.__module__ + assert actual.__closure__ == expected.__closure__ assert actual.__code__.co_code == expected.__code__.co_code @@ -106,3 +107,41 @@ def foo(add_op, mul_op, x, y, z): args = (ops.mul, ops.mul, 1.23, 4.56, 7.89) assert actual(*args) == expected(*args) + + +def test_alpha_rename_1(): + @alpha_rename + def fn(a, b, /, c, d=1): + e = a + b + return c + d + e + + actual = fn + + def fn(_bound_0, _bound_1, /, c, d=1): + e = _bound_0 + _bound_1 + return c + d + e + + expected = fn + assert_fn_eq(actual, expected) + + assert actual(1, 2, 3, 4) == expected(1, 2, 3, 4) + + +@pytest.mark.xfail(reason="failure to reproduce enclosing scope") +def test_alpha_rename_2(): + @alpha_rename(locals_=locals()) + def fn(a, b, /, c, d=1): + e = a + b + return c + d + e + f + + actual = fn + + def fn(_bound_0, _bound_1, /, c, d=1): + e = _bound_0 + _bound_1 + return c + d + e + f + + expected = fn + assert_fn_eq(actual, expected) + + f = 0 + assert actual(1, 2, 3, 4) == expected(1, 2, 3, 4)