Skip to content

Commit

Permalink
Allow do interventions to reference intervened variable
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 17, 2023
1 parent 110dfd9 commit 15c88e8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
12 changes: 11 additions & 1 deletion pymc_experimental/model_transform/conditioning.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import warnings
from typing import Any, List, Mapping, Optional, Sequence, Union

from pymc import Model
from pymc.logprob.transforms import RVTransform
from pymc.pytensorf import _replace_vars_in_graphs
from pymc.util import get_transformed_name, get_untransformed_name
from pytensor.graph import ancestors
from pytensor.tensor import TensorVariable

from pymc_experimental.model_transform.basic import (
Expand Down Expand Up @@ -199,7 +201,15 @@ def do(
# Just a sanity check
assert model_var in fgraph.variables

intervention.name = model_var.name
# If the intervention references the original variable we must give it a different name
if model_var in ancestors([intervention]):
intervention.name = f"do_{model_var.name}"
warnings.warn(
f"Intervention expression references the variable that is being intervened: {model_var.name}. "
f"Intervention will be given the name: {intervention.name}"
)
else:
intervention.name = model_var.name
dims = extract_dims(model_var)
# If there are any RVs in the graph we introduce the intervention as a deterministic
if rvs_in_graph([intervention]):
Expand Down
17 changes: 17 additions & 0 deletions pymc_experimental/tests/model_transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ def test_do_prune(prune):
assert set(do_m.named_vars) == orig_named_vars


def test_do_self_reference():
"""Check we can replace a variable by an expression that refers to the same variable."""
with pm.Model() as m:
x = pm.Normal("x", 0, 1)

with pytest.warns(
UserWarning,
match="Intervention expression references the variable that is being intervened",
):
new_m = do(m, {x: x + 100})

x = new_m["x"]
do_x = new_m["do_x"]
draw_x, draw_do_x = pm.draw([x, do_x], draws=5)
np.testing.assert_allclose(draw_x + 100, draw_do_x)


def test_change_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", 0, 1, transform=None)
Expand Down

0 comments on commit 15c88e8

Please sign in to comment.