diff --git a/funsor/adjoint.py b/funsor/adjoint.py
index dcd18d530..c9db34279 100644
--- a/funsor/adjoint.py
+++ b/funsor/adjoint.py
@@ -3,22 +3,27 @@
 
 from collections import defaultdict
 from collections.abc import Hashable
+from functools import reduce
 
 from funsor.cnf import Contraction, nullop
 from funsor.interpretations import Interpretation, reflect
 from funsor.interpreter import stack_reinterpret
 from funsor.ops import AssociativeOp
 from funsor.registry import KeyedRegistry
+from funsor.sum_product import MarkovProduct
 from funsor.terms import (
     Binary,
     Cat,
+    eager,
     Funsor,
+    Lambda,
     Reduce,
     Scatter,
     Slice,
     Subs,
     substitute,
     to_funsor,
+    Variable,
 )
 
 from . import instrument, interpreter, ops
@@ -65,11 +70,11 @@ def __enter__(self):
         self._old_interpretation = interpreter.get_interpretation()
         return super().__enter__()
 
-    def adjoint(self, sum_op, bin_op, root, targets=None):
+    def adjoint(self, sum_op, bin_op, root, targets=None, out_adj=None):
         zero = to_funsor(ops.UNITS[sum_op])
         one = to_funsor(ops.UNITS[bin_op])
         adjoint_values = defaultdict(lambda: zero)
-        adjoint_values[root] = one
+        adjoint_values[root] = one if out_adj is None else out_adj
 
         reached_root = False
         while self.tape:
@@ -127,11 +132,11 @@ def adjoint(self, sum_op, bin_op, root, targets=None):
         return {target: result[target] for target in targets}
 
 
-def adjoint(sum_op, bin_op, expr):
+def adjoint(sum_op, bin_op, expr, out_adj=None):
     with AdjointTape() as tape:
         # TODO fix traversal order in AdjointTape instead of using stack_reinterpret
         root = stack_reinterpret(expr)
-        return tape.adjoint(sum_op, bin_op, root)
+        return tape.adjoint(sum_op, bin_op, root, out_adj=out_adj)
 
 
 # logaddexp/add
@@ -147,6 +152,26 @@ def _fail_default(*args):
     )
 
 
+@adjoint_ops.register(
+    MarkovProduct, AssociativeOp, AssociativeOp, Funsor,
+    AssociativeOp, AssociativeOp, Funsor, Variable, frozenset, frozenset,
+)
+def adjoint_markovproduct(
+    adj_sum_op, adj_bin_op, out_adj, sum_op, prod_op, trans, time, step, step_names
+):
+    # Adjoint of MarkovProduct: swap ``trans`` for a fresh placeholder
+    # variable, rebuild the product eagerly, run the adjoint through the
+    # eager graph, then substitute the original ``trans`` back in.
+    input_vars = tuple(Variable(key, value) for key, value in trans.inputs.items())
+    trans_bound = reduce(lambda x, y: Lambda(y, x), input_vars, trans)
+    trans_placeholder = Variable("__trans", trans_bound.output)[tuple(trans.inputs)]
+    with eager:
+        expr = MarkovProduct(sum_op, prod_op, trans_placeholder, time, step, step_names)
+    bwd_expr = adjoint(adj_sum_op, adj_bin_op, expr, out_adj=out_adj)[trans_placeholder]
+    trans_adj = bwd_expr(__trans=trans_bound)
+    return ((trans, trans_adj),)
+
+
 @adjoint_ops.register(
     Binary, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, Funsor
 )
diff --git a/test/test_adjoint.py b/test/test_adjoint.py
index d4fc2696a..08e681f8e 100644
--- a/test/test_adjoint.py
+++ b/test/test_adjoint.py
@@ -209,6 +209,6 @@ def test_optimized_plated_einsum_adjoint(equation, plates, backend):
 @pytest.mark.parametrize(
     "impl",
     [
+        MarkovProduct,
         sequential_sum_product,
         naive_sequential_sum_product,
-        xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?"),
@@ -227,31 +227,32 @@ def test_sequential_sum_product_adjoint(
     trans = random_tensor(inputs)
     time = Variable("time", Bint[num_steps])
 
-    with AdjointTape() as actual_tape:
-        actual = impl(sum_op, prod_op, trans, time, {"prev": "curr"})
-        actual = actual.reduce(sum_op)
+    with funsor.terms.lazy:
+        with AdjointTape() as actual_tape:
+            actual = impl(sum_op, prod_op, trans, time, {"prev": "curr"})
+            actual = actual.reduce(sum_op)
 
-    # Check against contract.
-    operands = tuple(
-        trans(time=t, prev="t_{}".format(t), curr="t_{}".format(t + 1))
-        for t in range(num_steps)
-    )
-    reduce_vars = frozenset("t_{}".format(t) for t in range(1, num_steps))
-    with AdjointTape() as expected_tape:
-        with reflect:
-            lazy_expected = sum_product(sum_op, prod_op, operands, reduce_vars)
-        expected = apply_optimizer(lazy_expected)
-        expected = expected.reduce(sum_op)
+        # Check against contract.
+        operands = tuple(
+            trans(time=t, prev="t_{}".format(t), curr="t_{}".format(t + 1))
+            for t in range(num_steps)
+        )
+        reduce_vars = frozenset("t_{}".format(t) for t in range(1, num_steps))
+        with AdjointTape() as expected_tape:
+            expected = sum_product(sum_op, prod_op, operands, reduce_vars)
+            expected = expected.reduce(sum_op)
+        # perform backward passes only after the sanity check
+        expected_bwds = expected_tape.adjoint(sum_op, prod_op, expected, operands)
+        actual_bwd = actual_tape.adjoint(sum_op, prod_op, actual, (trans,))[trans]
+
+    actual = apply_optimizer(actual)
+    expected = apply_optimizer(expected)
 
     # check forward pass (sanity check)
     assert_close(
         actual, expected.align(tuple(actual.inputs.keys())), rtol=5e-3 * num_steps
     )
 
-    # perform backward passes only after the sanity check
-    expected_bwds = expected_tape.adjoint(sum_op, prod_op, expected, operands)
-    actual_bwd = actual_tape.adjoint(sum_op, prod_op, actual, (trans,))[trans]
-
     # check backward pass
     for t, operand in enumerate(operands):
         actual_bwd_t = actual_bwd(