From 95a4ceb71786a50cfff9e658fa3daeab7833622b Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 6 May 2023 00:58:49 -0500 Subject: [PATCH] Preserve the identities of valued/observed variables --- aeppl/joint_logprob.py | 16 +++------- aeppl/rewriting.py | 55 +++++++++++++++++++++------------ aeppl/transforms.py | 36 ++++++++++----------- tests/test_composite_logprob.py | 17 ++++------ tests/test_convolutions.py | 6 ++-- tests/test_mixture.py | 6 ++-- 6 files changed, 70 insertions(+), 66 deletions(-) diff --git a/aeppl/joint_logprob.py b/aeppl/joint_logprob.py index a11e2ef8..3d6b78f4 100644 --- a/aeppl/joint_logprob.py +++ b/aeppl/joint_logprob.py @@ -141,18 +141,16 @@ def conditional_logprob( # maps to the logprob graphs and value variables before returning them. rv_values = {**original_rv_values, **realized} - fgraph, _, memo = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter) - - if extra_rewrites is not None: - extra_rewrites.add_requirements(fgraph, rv_values, memo) - extra_rewrites.apply(fgraph) + fgraph, new_rv_values = construct_ir_fgraph( + rv_values, ir_rewriter=ir_rewriter, extra_rewrites=extra_rewrites + ) # We assign log-densities on a per-node basis, and not per-output/variable. realized_vars = set() new_to_old_rvs = {} nodes_to_vals: Dict["Apply", List[Tuple["Variable", "Variable"]]] = {} - for bnd_var, (old_mvar, old_val) in zip(fgraph.outputs, rv_values.items()): + for bnd_var, (old_mvar, val) in zip(fgraph.outputs, new_rv_values.items()): mnode = bnd_var.owner assert mnode and isinstance(mnode.op, ValuedVariable) @@ -165,11 +163,7 @@ def conditional_logprob( if old_mvar in realized: realized_vars.add(rv_var) - # Do this just in case a value variable was changed. (Some transforms - # do this.) - new_val = memo[old_val] - - nodes_to_vals.setdefault(rv_node, []).append((val_var, new_val)) + nodes_to_vals.setdefault(rv_node, []).append((val_var, val)) new_to_old_rvs[rv_var] = old_mvar diff --git a/aeppl/rewriting.py b/aeppl/rewriting.py index da3b8ad6..ef510b85 100644 --- a/aeppl/rewriting.py +++ b/aeppl/rewriting.py @@ -1,11 +1,16 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Union import aesara.tensor as at from aesara.compile.mode import optdb from aesara.graph.basic import Apply, Variable from aesara.graph.features import Feature from aesara.graph.fg import FunctionGraph -from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter +from aesara.graph.rewriting.basic import ( + GraphRewriter, + NodeRewriter, + in2out, + node_rewriter, +) from aesara.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.extra_ops import BroadcastTo @@ -180,9 +185,10 @@ def incsubtensor_rv_replace(fgraph, node): def construct_ir_fgraph( - rv_values: Dict[Variable, Variable], + rvs_to_values: Dict[Variable, Variable], ir_rewriter: Optional[GraphRewriter] = None, -) -> Tuple[FunctionGraph, Dict[Variable, Variable], Dict[Variable, Variable]]: + extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None, +) -> Tuple[FunctionGraph, Dict[Variable, Variable]]: r"""Construct a `FunctionGraph` in measurable IR form for the keys in `rv_values`. A custom IR rewriter can be specified. By default, @@ -215,9 +221,8 @@ def construct_ir_fgraph( Returns ------- A `FunctionGraph` of the measurable IR, a copy of `rv_values` containing - the new, cloned versions of the original variables in `rv_values`, and - a ``dict`` mapping all the original variables to their cloned values in - the `FunctionGraph`. + the new, cloned versions of the original variables in `rv_values`. + """ # We're going to create a `FunctionGraph` that effectively represents the @@ -233,16 +238,20 @@ def construct_ir_fgraph( # so that they're distinct nodes in the graph. This allows us to replace # all instances of the original random variables with their value # variables, while leaving the output clones untouched. - rv_value_clones = {} + rv_clone_to_value_clone = {} + rv_to_value_clone = {} + value_clone_to_value = {} measured_outputs = {} - memo = {} - for rv, val in rv_values.items(): + memo: Dict[Variable, Variable] = {} + for rv, val in rvs_to_values.items(): rv_node_clone = rv.owner.clone() rv_clone = rv_node_clone.outputs[rv.owner.outputs.index(rv)] - rv_value_clones[rv_clone] = val - measured_outputs[rv] = valued_variable(rv_clone, val) - # Prevent value variables from being cloned - memo[val] = val + val_clone = val.clone() + val_clone.name = "val_clone" + rv_clone_to_value_clone[rv_clone] = val_clone + rv_to_value_clone[rv] = val_clone + value_clone_to_value[val_clone] = val + measured_outputs[rv] = valued_variable(rv_clone, val_clone) # We add `ShapeFeature` because it will get rid of references to the old # `RandomVariable`s that have been lifted; otherwise, it will be difficult @@ -257,9 +266,6 @@ def construct_ir_fgraph( copy_inputs=False, ) - # Update `rv_values` so that it uses the new cloned variables - rv_value_clones = {memo[k]: v for k, v in rv_value_clones.items()} - # Replace valued non-output variables with their values fgraph.replace_all( [(memo[rv], val) for rv, val in measured_outputs.items() if rv in memo], @@ -272,11 +278,22 @@ def construct_ir_fgraph( ir_rewriter.rewrite(fgraph) + if extra_rewrites is not None: + # Expect `value_clone_to_value` to be updated in-place + extra_rewrites.add_requirements(fgraph, rv_to_value_clone, value_clone_to_value) + extra_rewrites.apply(fgraph) + # Undo un-valued measurable IR rewrites new_to_old = tuple((v, k) for k, v in fgraph.measurable_conversions.items()) - fgraph.replace_all(new_to_old, reason="undo-unvalued-measurables") + # and add the original value variables back in + new_to_old += tuple(value_clone_to_value.items()) + fgraph.replace_all( + new_to_old, reason="undo-unvalued-measurables", import_missing=True + ) + + new_rvs_to_values = dict(zip(rvs_to_values.keys(), value_clone_to_value.values())) - return fgraph, rv_value_clones, memo + return fgraph, new_rvs_to_values @register_useless diff --git a/aeppl/transforms.py b/aeppl/transforms.py index 73170a3d..4dada3f6 100644 --- a/aeppl/transforms.py +++ b/aeppl/transforms.py @@ -161,14 +161,10 @@ def transform_values(fgraph: FunctionGraph, node: Apply): 4. Replace the old `ValuedVariable` with a new one containing a `TransformedVariable` value. - Step 3. is currently accomplished by updating the `memo` dictionary - associated with the `FunctionGraph`. Our main entry-point, + Step 3. is currently accomplished by updating the `rvs_to_values` + dictionary associated with the `FunctionGraph`. Our main entry-point, `conditional_logprob`, checks this dictionary for value variable changes. - TODO: This approach is less than ideal, because it puts awkward demands on - users/callers of this rewrite to check with `memo`; let's see if we can do - something better. - The new value variable mentioned in Step 2. may be of a different `Type` (e.g. extra/fewer dimensions) than the original value variable; this is why we must replace the corresponding original value variables before we @@ -235,8 +231,8 @@ def transform_values(fgraph: FunctionGraph, node: Apply): # This effectively lets the caller know that a value variable has been # replaced (i.e. they should filter all their old value variables through - # the memo/replacements map). - fgraph.memo[value_var] = trans_value_var + # the replacements map). + fgraph.value_clone_to_value[value_var] = trans_value_var trans_var = trans_node.outputs[rv_var_out_idx] new_var = valued_variable(trans_var, untrans_value_var) @@ -252,7 +248,7 @@ class TransformValuesMapping(Feature): """ - def __init__(self, values_to_transforms, memo): + def __init__(self, values_to_transforms, value_clone_to_value): """ Parameters ========== @@ -261,20 +257,19 @@ def __init__(self, values_to_transforms, memo): value variable can be assigned one of `RVTransform`, `DEFAULT_TRANSFORM`, or ``None``. Random variables with no transform specified remain unchanged. - memo - Mapping from variables to their clones. This is updated - in-place whenever a value variable is transformed. - + value_clone_to_value + Mapping between random variable value clones and their original + value variables. """ self.values_to_transforms = values_to_transforms - self.memo = memo + self.value_clone_to_value = value_clone_to_value def on_attach(self, fgraph): if hasattr(fgraph, "values_to_transforms"): raise AlreadyThere() fgraph.values_to_transforms = self.values_to_transforms - fgraph.memo = self.memo + fgraph.value_clone_to_value = self.value_clone_to_value class TransformValuesRewrite(GraphRewriter): @@ -322,6 +317,7 @@ def __init__( measurable variable can be assigned an `RVTransform` instance, `DEFAULT_TRANSFORM`, or ``None``. Measurable variables with no transform specified remain unchanged. + rvs_to_values """ @@ -330,14 +326,16 @@ def __init__( def add_requirements( self, fgraph, - rv_to_values: Dict[TensorVariable, TensorVariable], - memo: Dict[TensorVariable, TensorVariable], + rvs_to_values: Dict[TensorVariable, TensorVariable], + value_clone_to_value: Dict[TensorVariable, TensorVariable], ): values_to_transforms = { - rv_to_values[rv]: transform + rvs_to_values[rv]: transform for rv, transform in self.rvs_to_transforms.items() } - values_transforms_feature = TransformValuesMapping(values_to_transforms, memo) + values_transforms_feature = TransformValuesMapping( + values_to_transforms, value_clone_to_value + ) fgraph.attach_feature(values_transforms_feature) def apply(self, fgraph: FunctionGraph): diff --git a/tests/test_composite_logprob.py b/tests/test_composite_logprob.py index 43770817..69ed7f6a 100644 --- a/tests/test_composite_logprob.py +++ b/tests/test_composite_logprob.py @@ -79,25 +79,20 @@ def test_unvalued_ir_reversion(): """Make sure that un-valued IR rewrites are reverted.""" srng = at.random.RandomStream(0) - x_rv = srng.normal() + x_rv = srng.normal(name="X") y_rv = at.clip(x_rv, 0, 1) - z_rv = srng.normal(y_rv, 1, name="z") + y_rv.name = "Y" + z_rv = srng.normal(y_rv, 1, name="Z") z_vv = z_rv.clone() + z_vv.name = "z" # Only the `z_rv` is "valued", so `y_rv` doesn't need to be converted into # measurable IR. rv_values = {z_rv: z_vv} - z_fgraph, _, memo = construct_ir_fgraph(rv_values) + z_fgraph, new_rvs_to_values = construct_ir_fgraph(rv_values) - assert memo[y_rv] in z_fgraph.measurable_conversions - - measurable_y_rv = z_fgraph.measurable_conversions[memo[y_rv]] - assert isinstance(measurable_y_rv.owner.op, MeasurableClip) - - # `construct_ir_fgraph` should've reverted the un-valued measurable IR - # change - assert measurable_y_rv not in z_fgraph + assert not any(isinstance(node.op, MeasurableClip) for node in z_fgraph.apply_nodes) def test_shifted_cumsum(): diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py index bbbc057f..40a99785 100644 --- a/tests/test_convolutions.py +++ b/tests/test_convolutions.py @@ -74,7 +74,7 @@ def test_add_independent_normals(mu_x, mu_y, sigma_x, sigma_y, x_shape, y_shape, Z_rv.name = "Z" z_vv = Z_rv.clone() - fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv}) + fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv}) (valued_var_out_node) = fgraph.outputs[0].owner # The convolution should be applied, and not the transform @@ -108,7 +108,7 @@ def test_normal_add_input_valued(): Z_rv.name = "Z" z_vv = Z_rv.clone() - fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, X_rv: x_vv}) + fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv, X_rv: x_vv}) valued_var_out_node = fgraph.outputs[0].owner # We should not expect the convolution to be applied; instead, the @@ -136,7 +136,7 @@ def test_normal_add_three_inputs(): Z_rv.name = "Z" z_vv = Z_rv.clone() - fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv}) + fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv}) valued_var_out_node = fgraph.outputs[0].owner # The convolution should be applied, and not the transform diff --git a/tests/test_mixture.py b/tests/test_mixture.py index 4bf86838..468742c2 100644 --- a/tests/test_mixture.py +++ b/tests/test_mixture.py @@ -685,7 +685,7 @@ def test_switch_mixture(): z_vv = Z1_rv.clone() z_vv.name = "z1" - fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) + fgraph, *_ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) out_rv = fgraph.outputs[0].owner.inputs[0] assert isinstance(out_rv.owner.op, MixtureRV) @@ -696,7 +696,7 @@ def test_switch_mixture(): Z1_rv.name = "Z1" - fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) + fgraph, *_ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) out_rv = fgraph.outputs[0].owner.inputs[0] assert out_rv.name == "Z1-mixture" @@ -705,7 +705,7 @@ def test_switch_mixture(): Z2_rv = at.stack((X_rv, Y_rv))[I_rv] - fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv}) + fgraph2, *_ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv}) assert equal_computations(fgraph.outputs, fgraph2.outputs)