From d7f5ab74d53c75a10daab8dbc1c01c1199867482 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 14 Sep 2023 10:56:13 +0200 Subject: [PATCH] Remove fgraph functionality that graduated to PyMC --- docs/api_reference.rst | 17 +- pymc_experimental/model_transform/basic.py | 46 -- .../model_transform/conditioning.py | 361 +--------------- .../tests/model_transform/test_basic.py | 19 - .../model_transform/test_conditioning.py | 298 +------------ .../tests/utils/test_model_fgraph.py | 342 +-------------- pymc_experimental/utils/__init__.py | 2 - pymc_experimental/utils/model_fgraph.py | 393 +----------------- 8 files changed, 39 insertions(+), 1439 deletions(-) delete mode 100644 pymc_experimental/model_transform/basic.py delete mode 100644 pymc_experimental/tests/model_transform/test_basic.py diff --git a/docs/api_reference.rst b/docs/api_reference.rst index f3f8b253..54a920bd 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -35,19 +35,6 @@ Distributions histogram_approximation -Model Transformations -===================== - -.. currentmodule:: pymc_experimental.model_transform -.. autosummary:: - :toctree: generated/ - - conditioning.do - conditioning.observe - conditioning.change_value_transforms - conditioning.remove_value_transforms - - Utils ===== @@ -55,11 +42,9 @@ Utils .. autosummary:: :toctree: generated/ - clone_model spline.bspline_interpolation prior.prior_from_idata - model_fgraph.fgraph_from_model - model_fgraph.model_from_fgraph + Statespace Models ================= diff --git a/pymc_experimental/model_transform/basic.py b/pymc_experimental/model_transform/basic.py deleted file mode 100644 index 8146f8fe..00000000 --- a/pymc_experimental/model_transform/basic.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import List, Sequence, Union - -from pymc import Model -from pytensor import Variable -from pytensor.graph import ancestors - -from pymc_experimental.utils.model_fgraph import ( - ModelObservedRV, - ModelVar, - fgraph_from_model, - model_from_fgraph, -) - -ModelVariable = Union[Variable, str] - - -def prune_vars_detached_from_observed(model: Model) -> Model: - """Prune model variables that are not related to any observed variable in the Model.""" - - # Potentials are ambiguous as whether they correspond to likelihood or prior terms, - # We simply raise for now - if model.potentials: - raise NotImplementedError("Pruning not implemented for models with Potentials") - - fgraph, _ = fgraph_from_model(model, inlined_views=True) - observed_vars = ( - out - for node in fgraph.apply_nodes - if isinstance(node.op, ModelObservedRV) - for out in node.outputs - ) - ancestor_nodes = {var.owner for var in ancestors(observed_vars)} - nodes_to_remove = { - node - for node in fgraph.apply_nodes - if isinstance(node.op, ModelVar) and node not in ancestor_nodes - } - for node_to_remove in nodes_to_remove: - fgraph.remove_node(node_to_remove) - return model_from_fgraph(fgraph) - - -def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> List[Variable]: - if not isinstance(vars, (list, tuple)): - vars = (vars,) - return [model[var] if isinstance(var, str) else var for var in vars] diff --git a/pymc_experimental/model_transform/conditioning.py b/pymc_experimental/model_transform/conditioning.py index c0b5a7b7..3aa8774c 100644 --- a/pymc_experimental/model_transform/conditioning.py +++ b/pymc_experimental/model_transform/conditioning.py @@ -1,358 +1,11 @@ +# pylint: disable=unused-import import warnings -from typing import Any, List, Mapping, 
Optional, Sequence, Union -from pymc import Model -from pymc.logprob.transforms import RVTransform -from pymc.pytensorf import _replace_vars_in_graphs -from pymc.util import get_transformed_name, get_untransformed_name -from pytensor.graph import ancestors -from pytensor.tensor import TensorVariable - -from pymc_experimental.model_transform.basic import ( - ModelVariable, - parse_vars, - prune_vars_detached_from_observed, -) -from pymc_experimental.utils.model_fgraph import ( - ModelDeterministic, - ModelFreeRV, - extract_dims, - fgraph_from_model, - model_deterministic, - model_free_rv, - model_from_fgraph, - model_named, - model_observed_rv, - toposort_replace, +from pymc.model.transform.conditioning import ( + change_value_transforms, + do, + observe, + remove_value_transforms, ) -from pymc_experimental.utils.pytensorf import rvs_in_graph - - -def observe( - model: Model, vars_to_observations: Mapping[Union["str", TensorVariable], Any] -) -> Model: - """Convert free RVs or Deterministics to observed RVs. - - Parameters - ---------- - model: PyMC Model - vars_to_observations: Dict of variable or name to TensorLike - Dictionary that maps model variables (or names) to observed values. - Observed values must have a shape and data type that is compatible - with the original model variable. - - Returns - ------- - new_model: PyMC model - A distinct PyMC model with the relevant variables observed. - All remaining variables are cloned and can be retrieved via `new_model["var_name"]`. - - Examples - -------- - - .. code-block:: python - - import pymc as pm - from pymc_experimental.model_transform.conditioning import observe - - with pm.Model() as m: - x = pm.Normal("x") - y = pm.Normal("y", x) - z = pm.Normal("z", y) - - m_new = observe(m, {y: 0.5}) - - Deterministic variables can also be observed. - This relies on PyMC ability to infer the logp of the underlying expression - - .. 
code-block:: python - - import pymc as pm - from pymc_experimental.model_transform.conditioning import observe - - with pm.Model() as m: - x = pm.Normal("x") - y = pm.Normal.dist(x, shape=(5,)) - y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1)) - - new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]}) - - - """ - vars_to_observations = { - model[var] if isinstance(var, str) else var: obs - for var, obs in vars_to_observations.items() - } - - valid_model_vars = set(model.free_RVs + model.deterministics) - if any(var not in valid_model_vars for var in vars_to_observations): - raise ValueError(f"At least one var is not a free variable or deterministic in the model") - - fgraph, memo = fgraph_from_model(model) - - replacements = {} - for var, obs in vars_to_observations.items(): - model_var = memo[var] - - # Just a sanity check - assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic)) - assert model_var in fgraph.variables - - var = model_var.owner.inputs[0] - var.name = model_var.name - dims = extract_dims(model_var) - model_obs_rv = model_observed_rv(var, var.type.filter_variable(obs), *dims) - replacements[model_var] = model_obs_rv - - toposort_replace(fgraph, tuple(replacements.items())) - - return model_from_fgraph(fgraph) - - -def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]: - def replacement_fn(var, inner_replacements): - if var in replacements: - inner_replacements[var] = replacements[var] - - # Handle root inputs as those will never be passed to the replacement_fn - for inp in var.owner.inputs: - if inp.owner is None and inp in replacements: - inner_replacements[inp] = replacements[inp] - - return [var] - - replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn) - return replaced_graphs - - -def do( - model: Model, - vars_to_interventions: Mapping[Union["str", TensorVariable], Any], - prune_vars=False, -) -> Model: - """Replace model variables by intervention variables. - - Intervention variables will either show up as `Data` or `Deterministics` in the new model, - depending on whether they depend on other RandomVariables or not. - - Parameters - ---------- - model: PyMC Model - vars_to_interventions: Dict of variable or name to TensorLike - Dictionary that maps model variables (or names) to intervention expressions. - Intervention expressions must have a shape and data type that is compatible - with the original model variable. - prune_vars: bool, defaults to False - Whether to prune model variables that are not connected to any observed variables, - after the interventions. - - Returns - ------- - new_model: PyMC model - A distinct PyMC model with the relevant variables replaced by the intervention expressions. - All remaining variables are cloned and can be retrieved via `new_model["var_name"]`. - - Examples - -------- - - .. 
code-block:: python - - import pymc as pm - from pymc_experimental.model_transform.conditioning import do - - with pm.Model() as m: - x = pm.Normal("x", 0, 1) - y = pm.Normal("y", x, 1) - z = pm.Normal("z", y + x, 1) - - # Dummy posterior, same as calling `pm.sample` - idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]}) - - # Replace `y` by a constant `100.0` - m_do = do(m, {y: 100.0}) - with m_do: - idata_do = pm.sample_posterior_predictive(idata_m, var_names="z") - - """ - do_mapping = {} - for var, obs in vars_to_interventions.items(): - if isinstance(var, str): - var = model[var] - try: - do_mapping[var] = var.type.filter_variable(obs) - except TypeError as err: - raise TypeError( - "Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables" - ) from err - - if any(var not in model.named_vars.values() for var in do_mapping): - raise ValueError(f"At least one var is not a named variable in the model") - - fgraph, memo = fgraph_from_model(model, inlined_views=True) - - # We need the interventions defined in terms of the IR fgraph representation, - # In case they reference other variables in the model - ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo) - - replacements = {} - for var, intervention in zip(do_mapping, ir_interventions): - model_var = memo[var] - - # Just a sanity check - assert model_var in fgraph.variables - - # If the intervention references the original variable we must give it a different name - if model_var in ancestors([intervention]): - intervention.name = f"do_{model_var.name}" - warnings.warn( - f"Intervention expression references the variable that is being intervened: {model_var.name}. " - f"Intervention will be given the name: {intervention.name}" - ) - else: - intervention.name = model_var.name - dims = extract_dims(model_var) - # If there are any RVs in the graph we introduce the intervention as a deterministic - if rvs_in_graph([intervention]): - new_var = model_deterministic(intervention.copy(name=intervention.name), *dims) - # Otherwise as a named variable (Constant or Shared data) - else: - new_var = model_named(intervention, *dims) - - replacements[model_var] = new_var - - # Replace variables by interventions - toposort_replace(fgraph, tuple(replacements.items())) - - model = model_from_fgraph(fgraph) - if prune_vars: - return prune_vars_detached_from_observed(model) - return model - - -def change_value_transforms( - model: Model, - vars_to_transforms: Mapping[ModelVariable, Union[RVTransform, None]], -) -> Model: - """Change the value variables transforms in the model - - Parameters - ---------- - model : Model - vars_to_transforms : Dict - Dictionary that maps RVs to new transforms to be applied to the respective value variables - - Returns - ------- - new_model : Model - Model with the updated transformed value variables - - Examples - -------- - Extract untransformed space Hessian after finding transformed space MAP - - .. 
code-block:: python - - import pymc as pm - from pymc.distributions.transforms import logodds - from pymc_experimental.model_transform.conditioning import change_value_transforms - - with pm.Model() as base_m: - p = pm.Uniform("p", 0, 1, transform=None) - w = pm.Binomial("w", n=9, p=p, observed=6) - - with change_value_transforms(base_m, {"p": logodds}) as transformed_p: - mean_q = pm.find_MAP() - - with change_value_transforms(transformed_p, {"p": None}) as untransformed_p: - new_p = untransformed_p['p'] - std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0] - - print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}") - # Mean, Standard deviation - # p 0.67, 0.16 - - """ - vars_to_transforms = { - parse_vars(model, var)[0]: transform for var, transform in vars_to_transforms.items() - } - - if set(vars_to_transforms.keys()) - set(model.free_RVs): - raise ValueError(f"All keys must be free variables in the model: {model.free_RVs}") - - fgraph, memo = fgraph_from_model(model) - - vars_to_transforms = {memo[var]: transform for var, transform in vars_to_transforms.items()} - replacements = {} - for node in fgraph.apply_nodes: - if not isinstance(node.op, ModelFreeRV): - continue - - [dummy_rv] = node.outputs - if dummy_rv not in vars_to_transforms: - continue - - transform = vars_to_transforms[dummy_rv] - - rv, value, *dims = node.inputs - - new_value = rv.type() - try: - untransformed_name = get_untransformed_name(value.name) - except ValueError: - untransformed_name = value.name - if transform: - new_name = get_transformed_name(untransformed_name, transform) - else: - new_name = untransformed_name - new_value.name = new_name - - new_dummy_rv = model_free_rv(rv, new_value, transform, *dims) - replacements[dummy_rv] = new_dummy_rv - - toposort_replace(fgraph, tuple(replacements.items())) - return model_from_fgraph(fgraph) - - -def remove_value_transforms( - model: Model, - vars: Optional[Sequence[ModelVariable]] = None, -) -> Model: - """Remove the value variables transforms in the model - - Parameters - ---------- - model : Model - vars : Model variables, optional - Model variables for which to remove transforms. Defaults to all transformed variables - - Returns - ------- - new_model : Model - Model with the removed transformed value variables - - Examples - -------- - Extract untransformed space Hessian after finding transformed space MAP - - .. 
code-block:: python - - import pymc as pm - from pymc_experimental.model_transform.conditioning import remove_value_transforms - - with pm.Model() as transformed_m: - p = pm.Uniform("p", 0, 1) - w = pm.Binomial("w", n=9, p=p, observed=6) - mean_q = pm.find_MAP() - - with remove_value_transforms(transformed_m) as untransformed_m: - new_p = untransformed_m["p"] - std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0] - print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}") - - # Mean, Standard deviation - # p 0.67, 0.16 - """ - if vars is None: - vars = model.free_RVs - return change_value_transforms(model, {var: None for var in vars}) +warnings.warn("The functionality in this module has been moved to PyMC") diff --git a/pymc_experimental/tests/model_transform/test_basic.py b/pymc_experimental/tests/model_transform/test_basic.py deleted file mode 100644 index a2771d01..00000000 --- a/pymc_experimental/tests/model_transform/test_basic.py +++ /dev/null @@ -1,19 +0,0 @@ -import pymc as pm - -from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed - - -def test_prune_vars_detached_from_observed(): - with pm.Model() as m: - obs_data = pm.MutableData("obs_data", 0) - a0 = pm.ConstantData("a0", 0) - a1 = pm.Normal("a1", a0) - a2 = pm.Normal("a2", a1) - pm.Normal("obs", a2, observed=obs_data) - - d0 = pm.ConstantData("d0", 0) - d1 = pm.Normal("d1", d0) - - assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"} - pruned_m = prune_vars_detached_from_observed(m) - assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} diff --git a/pymc_experimental/tests/model_transform/test_conditioning.py b/pymc_experimental/tests/model_transform/test_conditioning.py index 6fcc8240..6b3cf105 100644 --- a/pymc_experimental/tests/model_transform/test_conditioning.py +++ b/pymc_experimental/tests/model_transform/test_conditioning.py @@ -1,296 +1,28 @@ -import arviz as az -import numpy as np -import pymc as pm +import pymc import pytest -from pymc.distributions.transforms import logodds -from pymc.variational.minibatch_rv import create_minibatch_rv -from pytensor import config -from pymc_experimental.model_transform.conditioning import ( - change_value_transforms, - do, - observe, - remove_value_transforms, -) - - -def test_observe(): - with pm.Model() as m_old: - x = pm.Normal("x") - y = pm.Normal("y", x) - z = pm.Normal("z", y) - - m_new = observe(m_old, {y: 0.5}) - - assert len(m_new.free_RVs) == 2 - assert len(m_new.observed_RVs) == 1 - assert m_new["x"] in m_new.free_RVs - assert m_new["y"] in m_new.observed_RVs - assert m_new["z"] in m_new.free_RVs - - np.testing.assert_allclose( - m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}), - m_new.compile_logp()({"x": 0.9, "z": 1.4}), - ) - - # Test two substitutions - m_new = observe(m_old, {y: 0.5, z: 1.4}) - - assert len(m_new.free_RVs) == 1 - assert len(m_new.observed_RVs) == 2 - assert m_new["x"] in m_new.free_RVs - assert m_new["y"] in m_new.observed_RVs - assert m_new["z"] in m_new.observed_RVs - - np.testing.assert_allclose( - m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}), - m_new.compile_logp()({"x": 0.9}), - ) - - -def test_observe_minibatch(): - data = np.zeros((100,), dtype=config.floatX) - batch_size = 10 - with pm.Model() as m_old: - x = pm.Normal("x") - y = pm.Normal("y", x) - # Minibatch RVs are usually created with `total_size` kwarg - z_raw = pm.Normal.dist(y, shape=batch_size) - mb_z = create_minibatch_rv(z_raw, total_size=data.shape) - 
m_old.register_rv(mb_z, name="mb_z") - - mb_data = pm.Minibatch(data, batch_size=batch_size) - m_new = observe(m_old, {mb_z: mb_data}) - - assert len(m_new.free_RVs) == 2 - assert len(m_new.observed_RVs) == 1 - assert m_new["x"] in m_new.free_RVs - assert m_new["y"] in m_new.free_RVs - assert m_new["mb_z"] in m_new.observed_RVs - - np.testing.assert_allclose( - m_old.compile_logp()({"x": 0.9, "y": 0.5, "mb_z": np.zeros(10)}), - m_new.compile_logp()({"x": 0.9, "y": 0.5}), - ) - - -def test_observe_deterministic(): - y_censored_obs = np.array([0.9, 0.5, 0.3, 1, 1], dtype=config.floatX) - - with pm.Model() as m_old: - x = pm.Normal("x") - y = pm.Normal.dist(x, shape=(5,)) - y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1)) - - m_new = observe(m_old, {y_censored: y_censored_obs}) - - with pm.Model() as m_ref: - x = pm.Normal("x") - pm.Censored("y_censored", pm.Normal.dist(x), lower=-1, upper=1, observed=y_censored_obs) - - -def test_observe_dims(): - with pm.Model(coords={"test_dim": range(5)}) as m_old: - x = pm.Normal("x", dims="test_dim") - - m_new = observe(m_old, {x: np.arange(5, dtype=config.floatX)}) - assert m_new.named_vars_to_dims["x"] == ["test_dim"] - - -def test_do(): - rng = np.random.default_rng(seed=435) - with pm.Model() as m_old: - x = pm.Normal("x", 0, 1e-3) - y = pm.Normal("y", x, 1e-3) - z = pm.Normal("z", y + x, 1e-3) - - assert -5 < pm.draw(z, random_seed=rng) < 5 - - m_new = do(m_old, {y: x + 100}) - - assert len(m_new.free_RVs) == 2 - assert m_new["x"] in m_new.free_RVs - assert m_new["y"] in m_new.deterministics - assert m_new["z"] in m_new.free_RVs - - assert 95 < pm.draw(m_new["z"], random_seed=rng) < 105 - - # Test two substitutions - with m_old: - switch = pm.MutableData("switch", 1) - m_new = do(m_old, {y: 100 * switch, x: 100 * switch}) - - assert len(m_new.free_RVs) == 1 - assert m_new["y"] not in m_new.deterministics - assert m_new["x"] not in m_new.deterministics - assert m_new["z"] in m_new.free_RVs - - assert 195 < pm.draw(m_new["z"], random_seed=rng) < 205 - with m_new: - pm.set_data({"switch": 0}) - assert -5 < pm.draw(m_new["z"], random_seed=rng) < 5 - - -def test_do_posterior_predictive(): - with pm.Model() as m: - x = pm.Normal("x", 0, 1) - y = pm.Normal("y", x, 1) - z = pm.Normal("z", y + x, 1e-3) - - # Dummy posterior - idata_m = az.from_dict( - { - "x": np.full((2, 500), 25), - "y": np.full((2, 500), np.nan), - "z": np.full((2, 500), np.nan), - } - ) - - # Replace `y` by a constant `100.0` - m_do = do(m, {y: 100.0}) - with m_do: - idata_do = pm.sample_posterior_predictive(idata_m, var_names="z") - - assert 120 < idata_do.posterior_predictive["z"].mean() < 130 - - -@pytest.mark.parametrize("mutable", (False, True)) -def test_do_constant(mutable): - rng = np.random.default_rng(seed=122) - with pm.Model() as m: - x = pm.Data("x", 0, mutable=mutable) - y = pm.Normal("y", x, 1e-3) - - do_m = do(m, {x: 105}) - assert pm.draw(do_m["y"], random_seed=rng) > 100 - - -def test_do_deterministic(): - rng = np.random.default_rng(seed=435) - with pm.Model() as m: - x = pm.Normal("x", 0, 1e-3) - y = pm.Deterministic("y", x + 105) - z = pm.Normal("z", y, 1e-3) - - do_m = do(m, {"z": x - 105}) - assert pm.draw(do_m["z"], random_seed=rng) < 100 - - -def test_do_dims(): - coords = {"test_dim": range(10)} - with pm.Model(coords=coords) as m: - x = pm.Normal("x", dims="test_dim") - y = pm.Deterministic("y", x + 5, dims="test_dim") - - do_m = do( - m, - {"x": np.zeros(10, dtype=config.floatX)}, - ) - assert do_m.named_vars_to_dims["x"] == ["test_dim"] - 
- do_m = do( - m, - {"y": np.zeros(10, dtype=config.floatX)}, - ) - assert do_m.named_vars_to_dims["y"] == ["test_dim"] - - -@pytest.mark.parametrize("prune", (False, True)) -def test_do_prune(prune): - - with pm.Model() as m: - x0 = pm.ConstantData("x0", 0) - x1 = pm.ConstantData("x1", 0) - y = pm.Normal("y") - y_det = pm.Deterministic("y_det", y + x0) - z = pm.Normal("z", y_det) - llike = pm.Normal("llike", z + x1, observed=0) - - orig_named_vars = {"x0", "x1", "y", "y_det", "z", "llike"} - assert set(m.named_vars) == orig_named_vars - - do_m = do(m, {y_det: x0 + 5}, prune_vars=prune) - if prune: - assert set(do_m.named_vars) == {"x0", "x1", "y_det", "z", "llike"} - else: - assert set(do_m.named_vars) == orig_named_vars - - do_m = do(m, {z: 0.5}, prune_vars=prune) - if prune: - assert set(do_m.named_vars) == {"x1", "z", "llike"} - else: - assert set(do_m.named_vars) == orig_named_vars - - -def test_do_self_reference(): - """Check we can replace a variable by an expression that refers to the same variable.""" - with pm.Model() as m: - x = pm.Normal("x", 0, 1) +def test_imports_from_pymc(): with pytest.warns( UserWarning, - match="Intervention expression references the variable that is being intervened", + match="The functionality in this module has been moved to PyMC", ): - new_m = do(m, {x: x + 100}) - - x = new_m["x"] - do_x = new_m["do_x"] - draw_x, draw_do_x = pm.draw([x, do_x], draws=5) - np.testing.assert_allclose(draw_x + 100, draw_do_x) - - -def test_change_value_transforms(): - with pm.Model() as base_m: - p = pm.Uniform("p", 0, 1, transform=None) - w = pm.Binomial("w", n=9, p=p, observed=6) - assert base_m.rvs_to_transforms[p] is None - assert base_m.rvs_to_values[p].name == "p" - - with change_value_transforms(base_m, {"p": logodds}) as transformed_p: - new_p = transformed_p["p"] - assert transformed_p.rvs_to_transforms[new_p] == logodds - assert transformed_p.rvs_to_values[new_p].name == "p_logodds__" - mean_q = pm.find_MAP(progressbar=False) - - with change_value_transforms(transformed_p, {"p": None}) as untransformed_p: - new_p = untransformed_p["p"] - assert untransformed_p.rvs_to_transforms[new_p] is None - assert untransformed_p.rvs_to_values[new_p].name == "p" - std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0] - - np.testing.assert_allclose(np.round(mean_q["p"], 2), 0.67) - np.testing.assert_allclose(np.round(std_q[0], 2), 0.16) - - -def test_change_value_transforms_error(): - with pm.Model() as m: - x = pm.Uniform("x", observed=5.0) + from pymc_experimental.model_transform.conditioning import do as fn - with pytest.raises(ValueError, match="All keys must be free variables in the model"): - change_value_transforms(m, {x: logodds}) + assert fn is pymc.do + from pymc_experimental.model_transform.conditioning import observe as fn -def test_remove_value_transforms(): - with pm.Model() as base_m: - p = pm.Uniform("p", transform=logodds) - q = pm.Uniform("q", transform=logodds) + assert fn is pymc.observe - new_m = remove_value_transforms(base_m) - new_p = new_m["p"] - new_q = new_m["q"] - assert new_m.rvs_to_transforms == {new_p: None, new_q: None} + from pymc_experimental.model_transform.conditioning import ( + change_value_transforms as fn, + ) - new_m = remove_value_transforms(base_m, [p, q]) - new_p = new_m["p"] - new_q = new_m["q"] - assert new_m.rvs_to_transforms == {new_p: None, new_q: None} + assert fn is pymc.model.transform.conditioning.change_value_transforms - new_m = remove_value_transforms(base_m, [p]) - new_p = new_m["p"] - new_q = new_m["q"] - 
assert new_m.rvs_to_transforms == {new_p: None, new_q: logodds} + from pymc_experimental.model_transform.conditioning import ( + remove_value_transforms as fn, + ) - new_m = remove_value_transforms(base_m, ["q"]) - new_p = new_m["p"] - new_q = new_m["q"] - assert new_m.rvs_to_transforms == {new_p: logodds, new_q: None} + assert fn is pymc.model.transform.conditioning.remove_value_transforms diff --git a/pymc_experimental/tests/utils/test_model_fgraph.py b/pymc_experimental/tests/utils/test_model_fgraph.py index f5c38edf..0ef3ef88 100644 --- a/pymc_experimental/tests/utils/test_model_fgraph.py +++ b/pymc_experimental/tests/utils/test_model_fgraph.py @@ -1,338 +1,20 @@ -import numpy as np -import pymc as pm -import pytensor.tensor as pt +import pymc import pytest -from pytensor import config, shared -from pytensor.graph import Constant, FunctionGraph, node_rewriter -from pytensor.graph.rewriting.basic import in2out -from pytensor.tensor.exceptions import NotScalarConstantError -from pymc_experimental.utils.model_fgraph import ( - ModelDeterministic, - ModelFreeRV, - ModelNamed, - ModelObservedRV, - ModelPotential, - ModelVar, - clone_model, - fgraph_from_model, - model_deterministic, - model_free_rv, - model_from_fgraph, -) +def test_imports_from_pymc(): + with pytest.warns( + UserWarning, + match="The functionality in this module has been moved to PyMC", + ): + from pymc_experimental.utils.model_fgraph import fgraph_from_model as fn -def test_basic(): - """Test we can convert from a PyMC Model to a FunctionGraph and back""" - with pm.Model(coords={"test_dim": range(3)}) as m_old: - x = pm.Normal("x") - y = pm.Deterministic("y", x + 1) - w = pm.HalfNormal("w", pm.math.exp(y)) - z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",)) - pot = pm.Potential("pot", x * 2) + assert fn is pymc.model.fgraph.fgraph_from_model - m_fgraph, memo = fgraph_from_model(m_old) - assert isinstance(m_fgraph, FunctionGraph) + from pymc_experimental.utils.model_fgraph import model_from_fgraph as fn - assert isinstance(memo[x].owner.op, ModelFreeRV) - assert isinstance(memo[y].owner.op, ModelDeterministic) - assert isinstance(memo[w].owner.op, ModelFreeRV) - assert isinstance(memo[z].owner.op, ModelObservedRV) - assert isinstance(memo[pot].owner.op, ModelPotential) + assert fn is pymc.model.fgraph.model_from_fgraph - m_new = model_from_fgraph(m_fgraph) - assert isinstance(m_new, pm.Model) + from pymc_experimental.utils.model_fgraph import clone_model as fn - assert m_new.coords == {"test_dim": tuple(range(3))} - assert m_new._dim_lengths["test_dim"].eval() == 3 - assert m_new.named_vars_to_dims == {"z": ["test_dim"]} - - named_vars = {"x", "y", "w", "z", "pot"} - assert set(m_new.named_vars) == named_vars - for named_var in named_vars: - assert m_new[named_var] is not m_old[named_var] - for value_new, value_old in zip(m_new.rvs_to_values.values(), m_old.rvs_to_values.values()): - # Constants are not cloned - if not isinstance(value_new, Constant): - assert value_new is not value_old - assert m_new["x"] in m_new.free_RVs - assert m_new["w"] in m_new.free_RVs - assert m_new["y"] in m_new.deterministics - assert m_new["z"] in m_new.observed_RVs - assert m_new["pot"] in m_new.potentials - assert m_new.rvs_to_transforms[m_new["x"]] is None - assert m_new.rvs_to_transforms[m_new["w"]] is pm.distributions.transforms.log - assert m_new.rvs_to_transforms[m_new["z"]] is None - - # Test random - new_y_draw, new_z_draw = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1) - old_y_draw, old_z_draw = 
pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1) - np.testing.assert_array_equal(new_y_draw, old_y_draw) - np.testing.assert_array_equal(new_z_draw, old_z_draw) - - # Test logp - ip = m_new.initial_point() - np.testing.assert_equal( - m_new.compile_logp()(ip), - m_old.compile_logp()(ip), - ) - - -def same_storage(shared_1, shared_2) -> bool: - """Check if two shared variables have the same storage containers (i.e., they point to the same memory).""" - return shared_1.container.storage is shared_2.container.storage - - -@pytest.mark.parametrize("inline_views", (False, True)) -def test_data(inline_views): - """Test shared RNGs, MutableData, ConstantData and dim lengths are handled correctly. - - All model-related shared variables should be copied to become independent across models. - """ - with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old: - x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",)) - y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",)) - b0 = pm.ConstantData("b0", np.zeros((1,))) - b1 = pm.DiracDelta("b1", 1.0) - mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",)) - obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",)) - - m_fgraph, memo = fgraph_from_model(m_old, inlined_views=inline_views) - assert isinstance(memo[x].owner.op, ModelNamed) - assert isinstance(memo[y].owner.op, ModelNamed) - assert isinstance(memo[b0].owner.op, ModelNamed) - mu_inp = memo[mu].owner.inputs[0] - obs = memo[obs] - if not inline_views: - # Add(b0, Mul(FreeRV(b1), x) not Add(Named(b0), Mul(FreeRV(b1), Named(x)) - assert mu_inp.owner.inputs[0] is memo[b0].owner.inputs[0] - assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x].owner.inputs[0] - # ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims) - assert obs.owner.inputs[1] is memo[y].owner.inputs[0] - else: - assert mu_inp.owner.inputs[0] is memo[b0] - assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x] - assert obs.owner.inputs[1] is memo[y] - - m_new = model_from_fgraph(m_fgraph) - - # The rv-data mapping is preserved - assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"] - - # ConstantData is still accessible as a model variable - np.testing.assert_array_equal(m_new["b0"], m_old["b0"]) - - # Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory - assert not same_storage(m_new["x"], x) - assert not same_storage(m_new["y"], y) - assert not same_storage(m_new["b1"].owner.inputs[0], b1.owner.inputs[0]) - assert not same_storage(m_new.dim_lengths["test_dim"], m_old.dim_lengths["test_dim"]) - - # Updating model shared variables in new model, doesn't affect old one - with m_new: - pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)}) - assert m_new.dim_lengths["test_dim"].eval() == 2 - assert m_old.dim_lengths["test_dim"].eval() == 3 - np.testing.assert_allclose(pm.draw(m_new["mu"]), [100.0, 200.0]) - np.testing.assert_allclose(pm.draw(m_old["mu"]), [0.0, 1.0, 2.0], atol=1e-6) - - -@config.change_flags(floatX="float64") # Avoid downcasting Ops in the graph -def test_shared_variable(): - """Test that user defined shared variables (other than RNGs) aren't copied.""" - x = shared(np.array([1, 2, 3.0]), name="x") - y = shared(np.array([1, 2, 3.0]), name="y") - - with pm.Model() as m_old: - test = pm.Normal("test", mu=x, observed=y) - - assert test.owner.inputs[3] is x - assert m_old.rvs_to_values[test] is y - - m_new = clone_model(m_old) - test_new = m_new["test"] - # Shared Variables are cloned but still point to the 
same memory - assert test_new.owner.inputs[3] is not x - assert m_new.rvs_to_values[test_new] is not y - assert same_storage(test_new.owner.inputs[3], x) - assert same_storage(m_new.rvs_to_values[test_new], y) - - -@pytest.mark.parametrize("inline_views", (False, True)) -def test_deterministics(inline_views): - """Test handling of deterministics. - - We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome - However we want them in the middle of Model.basic_RVs, so they display nicely in graphviz - - There is one edge case that has to be considered, when a Deterministic is just a copy of a RV. - In that case we don't bother to reintroduce it in between other Model.basic_RVs - """ - with pm.Model() as m: - x = pm.Normal("x") - mu = pm.Deterministic("mu", pm.math.abs(x)) - sigma = pm.math.exp(x) - pm.Deterministic("sigma", sigma) - y = pm.Normal("y", mu, sigma) - # Special case where the Deterministic - # is a direct view on another model variable - y_ = pm.Deterministic("y_", y) - # Just for kicks, make it a double one! - y__ = pm.Deterministic("y__", y_) - z = pm.Normal("z", y__) - - # Deterministic mu is in the graph of x to y but not sigma - assert m["y"].owner.inputs[3] is m["mu"] - assert m["y"].owner.inputs[4] is not m["sigma"] - - fg, _ = fgraph_from_model(m, inlined_views=inline_views) - - # Check that no Deterministics are in graph of x to y and y to z - x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs - # [Det(mu), Det(sigma)] - mu = det_mu.owner.inputs[0] - sigma = det_sigma.owner.inputs[0] - assert y.owner.inputs[0].owner.inputs[4] is sigma - assert det_y_ is not det_y__ - assert det_y_.owner.inputs[0] is y - if not inline_views: - # FreeRV(y(mu, sigma)) not FreeRV(y(Det(mu), Det(sigma))) - assert y.owner.inputs[0].owner.inputs[3] is mu - # FreeRV(z(y)) not FreeRV(z(Det(Det(y)))) - assert z.owner.inputs[0].owner.inputs[3] is y - # Det(y), not Det(Det(y)) - assert det_y__.owner.inputs[0] is y - else: - assert y.owner.inputs[0].owner.inputs[3] is det_mu - assert z.owner.inputs[0].owner.inputs[3] is det_y__ - assert det_y__.owner.inputs[0] is det_y_ - - # Both mu and sigma deterministics are now in the graph of x to y - m = model_from_fgraph(fg) - assert m["y"].owner.inputs[3] is m["mu"] - assert m["y"].owner.inputs[4] is m["sigma"] - # But not y_* in y to z, since there was no real Op in between - assert m["z"].owner.inputs[3] is m["y"] - assert m["y_"].owner.inputs[0] is m["y"] - assert m["y__"].owner.inputs[0] is m["y"] - - -def test_context_error(): - """Test that model_from_fgraph fails when called inside a Model context. - - We can't allow it, because the new Model that's returned would be a child of whatever Model context is active. 
- """ - with pm.Model() as m: - x = pm.Normal("x") - - fg = fgraph_from_model(m) - - with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"): - model_from_fgraph(fg) - - -def test_sub_model_error(): - """Test Error is raised when trying to convert a sub-model to fgraph.""" - with pm.Model() as m: - x = pm.Beta("x", 1, 1) - with pm.Model() as sub_m: - y = pm.Normal("y", x) - - nodes = [v for v in fgraph_from_model(m)[0].toposort() if not isinstance(v.op, ModelVar)] - assert len(nodes) == 2 - assert isinstance(nodes[0].op, pm.Beta) - assert isinstance(nodes[1].op, pm.Normal) - - with pytest.raises(ValueError, match="Nested sub-models cannot be converted"): - fgraph_from_model(sub_m) - - -@pytest.fixture() -def non_centered_rewrite(): - @node_rewriter(tracks=[ModelFreeRV]) - def non_centered_param(fgraph: FunctionGraph, node): - """Rewrite that replaces centered normal by non-centered parametrization.""" - - rv, value, *dims = node.inputs - if not isinstance(rv.owner.op, pm.Normal): - return - rng, size, dtype, loc, scale = rv.owner.inputs - - # Only apply rewrite if size information is explicit - if size.ndim == 0: - return None - - try: - is_unit = ( - pt.get_underlying_scalar_constant_value(loc) == 0 - and pt.get_underlying_scalar_constant_value(scale) == 1 - ) - except NotScalarConstantError: - is_unit = False - - # Nothing to do here - if is_unit: - return - - raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng) - raw_norm.name = f"{rv.name}_raw_" - raw_norm_value = raw_norm.clone() - fgraph.add_input(raw_norm_value) - raw_norm = model_free_rv(raw_norm, raw_norm_value, node.op.transform, *dims) - - new_norm = loc + raw_norm * scale - new_norm.name = rv.name - new_norm_det = model_deterministic(new_norm, *dims) - fgraph.add_output(new_norm_det) - - return [new_norm] - - return in2out(non_centered_param) - - -def test_fgraph_rewrite(non_centered_rewrite): - """Test we can apply a simple rewrite to a PyMC Model.""" - - with pm.Model(coords={"subject": range(10)}) as m_old: - group_mean = pm.Normal("group_mean") - group_std = pm.HalfNormal("group_std") - subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",)) - obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",)) - - fg, _ = fgraph_from_model(m_old) - non_centered_rewrite.apply(fg) - - m_new = model_from_fgraph(fg) - assert m_new.named_vars_to_dims == { - "subject_mean": ["subject"], - "subject_mean_raw_": ["subject"], - "obs": ["subject"], - } - assert set(m_new.named_vars) == { - "group_mean", - "group_std", - "subject_mean_raw_", - "subject_mean", - "obs", - } - assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"} - assert {rv.name for rv in m_new.observed_RVs} == {"obs"} - assert {rv.name for rv in m_new.deterministics} == {"subject_mean"} - - with pm.Model() as m_ref: - group_mean = pm.Normal("group_mean") - group_std = pm.HalfNormal("group_std") - subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,)) - subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std) - obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10)) - - np.testing.assert_array_equal( - pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1), - pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1), - ) - - ip = m_new.initial_point() - np.testing.assert_equal( - m_new.compile_logp()(ip), - m_ref.compile_logp()(ip), - ) + assert fn is pymc.model.fgraph.clone_model diff --git 
a/pymc_experimental/utils/__init__.py b/pymc_experimental/utils/__init__.py index 705d2107..7844237d 100644 --- a/pymc_experimental/utils/__init__.py +++ b/pymc_experimental/utils/__init__.py @@ -15,10 +15,8 @@ from pymc_experimental.utils import prior, spline from pymc_experimental.utils.linear_cg import linear_cg -from pymc_experimental.utils.model_fgraph import clone_model __all__ = ( - "clone_model", "linear_cg", "prior", "spline", diff --git a/pymc_experimental/utils/model_fgraph.py b/pymc_experimental/utils/model_fgraph.py index 706ff613..285b5189 100644 --- a/pymc_experimental/utils/model_fgraph.py +++ b/pymc_experimental/utils/model_fgraph.py @@ -1,391 +1,6 @@ -from copy import copy -from typing import Dict, Optional, Sequence, Tuple +# pylint: disable=unused-import +import warnings -import pytensor -from pymc.logprob.transforms import RVTransform -from pymc.model import Model -from pymc.pytensorf import find_rng_nodes -from pytensor import Variable, shared -from pytensor.compile import SharedVariable -from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter -from pytensor.graph.rewriting.basic import out2in -from pytensor.scalar import Identity -from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.sharedvar import ScalarSharedVariable +from pymc.model.fgraph import clone_model, fgraph_from_model, model_from_fgraph -from pymc_experimental.utils.pytensorf import StringType - - -class ModelVar(Op): - """A dummy Op that describes the purpose of a Model variable and contains - meta-information as additional inputs (value and dims). - """ - - def make_node(self, rv, *dims): - assert isinstance(rv, Variable) - dims = self._parse_dims(rv, *dims) - return Apply(self, [rv, *dims], [rv.type(name=rv.name)]) - - def _parse_dims(self, rv, *dims): - if dims: - dims = [pytensor.as_symbolic(dim) for dim in dims] - assert all(isinstance(dim.type, StringType) for dim in dims) - assert len(dims) == rv.type.ndim - return dims - - def infer_shape(self, fgraph, node, inputs_shape): - return [inputs_shape[0]] - - def do_constant_folding(self, fgraph, node): - return False - - def perform(self, *args, **kwargs): - raise RuntimeError("ModelVars should never be in a final graph!") - - -class ModelValuedVar(ModelVar): - - __props__ = ("transform",) - - def __init__(self, transform: Optional[RVTransform] = None): - if transform is not None and not isinstance(transform, RVTransform): - raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}") - self.transform = transform - super().__init__() - - def make_node(self, rv, value, *dims): - assert isinstance(rv, Variable) - dims = self._parse_dims(rv, *dims) - if value is not None: - assert isinstance(value, Variable) - assert rv.type.in_same_class(value.type) - return Apply(self, [rv, value, *dims], [rv.type(name=rv.name)]) - - -class ModelFreeRV(ModelValuedVar): - pass - - -class ModelObservedRV(ModelValuedVar): - pass - - -class ModelPotential(ModelVar): - pass - - -class ModelDeterministic(ModelVar): - pass - - -class ModelNamed(ModelVar): - pass - - -def model_free_rv(rv, value, transform, *dims): - return ModelFreeRV(transform=transform)(rv, value, *dims) - - -model_observed_rv = ModelObservedRV() -model_potential = ModelPotential() -model_deterministic = ModelDeterministic() -model_named = ModelNamed() - - -def toposort_replace( - fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False -) -> None: - """Replace multiple variables in topological order.""" - toposort = 
fgraph.toposort() - sorted_replacements = sorted( - replacements, - key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1, - reverse=reverse, - ) - fgraph.replace_all(sorted_replacements, import_missing=True) - - -@node_rewriter([Elemwise]) -def local_remove_identity(fgraph, node): - if isinstance(node.op.scalar_op, Identity): - return [node.inputs[0]] - - -remove_identity_rewrite = out2in(local_remove_identity) - - -def fgraph_from_model( - model: Model, inlined_views=False -) -> Tuple[FunctionGraph, Dict[Variable, Variable]]: - """Convert Model to FunctionGraph. - - See: model_from_fgraph - - Parameters - ---------- - model: PyMC model - inlined_views: bool, default False - Whether "view" variables (Deterministics and Data) should be inlined among RVs in the fgraph, - or show up as separate branches. - - Returns - ------- - fgraph: FunctionGraph - FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops. - It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`. - - memo: Dict - A dictionary mapping original model variables to the equivalent nodes in the fgraph. - """ - - if any(v is not None for v in model.rvs_to_initial_values.values()): - raise NotImplementedError("Cannot convert models with non-default initial_values") - - if model.parent is not None: - raise ValueError( - "Nested sub-models cannot be converted to fgraph. Convert the parent model instead" - ) - - # Collect PyTensor variables - rvs_to_values = model.rvs_to_values - rvs = list(rvs_to_values.keys()) - free_rvs = model.free_RVs - observed_rvs = model.observed_RVs - potentials = model.potentials - named_vars = model.named_vars.values() - # We copy Deterministics (Identity Op) so that they don't show in between "main" variables - # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator - old_deterministics = model.deterministics - deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics] - # Value variables (we also have to decide whether to inline named ones) - old_value_vars = list(rvs_to_values.values()) - unnamed_value_vars = [val for val in old_value_vars if val not in named_vars] - named_value_vars = [ - val if inlined_views else val.copy(val.name) for val in old_value_vars if val in named_vars - ] - value_vars = old_value_vars.copy() - if inlined_views: - # In this case we want to use the named_value_vars as the value_vars in RVs - for named_val in named_value_vars: - idx = value_vars.index(named_val) - value_vars[idx] = named_val - # Other variables that are in named_vars but are not any of the categories above - # E.g., MutableData, ConstantData, _dim_lengths - # We use the same trick as deterministics! - accounted_for = set(free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars) - other_named_vars = [ - var if inlined_views else var.copy(var.name) - for var in named_vars - if var not in accounted_for - ] - - model_vars = ( - rvs + potentials + deterministics + other_named_vars + named_value_vars + unnamed_value_vars - ) - - memo = {} - - # Replace the following shared variables in the model: - # 1. RNGs - # 2. MutableData (could increase memory usage significantly) - # 3. 
Mutable coords dim lengths - shared_vars_to_copy = find_rng_nodes(model_vars) - shared_vars_to_copy += [v for v in model.dim_lengths.values() if isinstance(v, SharedVariable)] - shared_vars_to_copy += [v for v in model.named_vars.values() if isinstance(v, SharedVariable)] - for var in shared_vars_to_copy: - # FIXME: ScalarSharedVariables are converted to 0d numpy arrays internally, - # so calling shared(shared(5).get_value()) returns a different type: TensorSharedVariables! - # Furthermore, PyMC silently ignores mutable dim changes that are SharedTensorVariables... - # https://github.com/pymc-devs/pytensor/issues/396 - if isinstance(var, ScalarSharedVariable): - new_var = shared(var.get_value(borrow=False).item()) - else: - new_var = shared(var.get_value(borrow=False)) - - assert new_var.type == var.type - new_var.name = var.name - new_var.tag = copy(var.tag) - # We can replace input variables by placing them in the memo - memo[var] = new_var - - fgraph = FunctionGraph( - outputs=model_vars, - clone=True, - memo=memo, - copy_orphans=True, - copy_inputs=True, - ) - # Copy model meta-info to fgraph - fgraph._coords = model._coords.copy() - fgraph._dim_lengths = {k: memo.get(v, v) for k, v in model._dim_lengths.items()} - - rvs_to_transforms = model.rvs_to_transforms - named_vars_to_dims = model.named_vars_to_dims - - # Introduce dummy `ModelVar` Ops - free_rvs_to_transforms = {memo[k]: tr for k, tr in rvs_to_transforms.items()} - free_rvs_to_values = {memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in free_rvs} - observed_rvs_to_values = { - memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in observed_rvs - } - potentials = [memo[k] for k in potentials] - deterministics = [memo[k] for k in deterministics] - named_vars = [memo[k] for k in other_named_vars + named_value_vars] - - vars = fgraph.outputs - new_vars = [] - for var in vars: - dims = named_vars_to_dims.get(var.name, ()) - if var in free_rvs_to_values: - new_var = model_free_rv( - var, free_rvs_to_values[var], free_rvs_to_transforms[var], *dims - ) - elif var in observed_rvs_to_values: - new_var = model_observed_rv(var, observed_rvs_to_values[var], *dims) - elif var in potentials: - new_var = model_potential(var, *dims) - elif var in deterministics: - new_var = model_deterministic(var, *dims) - elif var in named_vars: - new_var = model_named(var, *dims) - else: - # Unnamed value variables - new_var = var - new_vars.append(new_var) - - replacements = tuple(zip(vars, new_vars)) - toposort_replace(fgraph, replacements, reverse=True) - - # Reference model vars in memo - inverse_memo = {v: k for k, v in memo.items()} - for var, model_var in replacements: - if not inlined_views and ( - model_var.owner and isinstance(model_var.owner.op, (ModelDeterministic, ModelNamed)) - ): - # Ignore extra identity that will be removed at the end - var = var.owner.inputs[0] - original_var = inverse_memo[var] - memo[original_var] = model_var - - # Remove the last outputs corresponding to unnamed value variables, now that they are graph inputs - first_idx_to_remove = len(fgraph.outputs) - len(unnamed_value_vars) - for _ in unnamed_value_vars: - fgraph.remove_output(first_idx_to_remove) - - # Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph - remove_identity_rewrite.apply(fgraph) - - return fgraph, memo - - -def model_from_fgraph(fgraph: FunctionGraph) -> Model: - """Convert FunctionGraph to PyMC model. - - This requires nodes to be properly tagged with `ModelVar` dummy Ops. 
- - See: fgraph_from_model - """ - - def first_non_model_var(var): - if var.owner and isinstance(var.owner.op, ModelVar): - new_var = var.owner.inputs[0] - return first_non_model_var(new_var) - else: - return var - - model = Model() - if model.parent is not None: - raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context") - model._coords = getattr(fgraph, "_coords", {}) - model._dim_lengths = getattr(fgraph, "_dim_lengths", {}) - - # Replace dummy `ModelVar` Ops by the underlying variables, - fgraph = fgraph.clone() - model_dummy_vars = [ - model_node.outputs[0] - for model_node in fgraph.toposort() - if isinstance(model_node.op, ModelVar) - ] - model_dummy_vars_to_vars = { - # Deterministics could refer to other model variables directly, - # We make sure to replace them by the first non-model variable - dummy_var: first_non_model_var(dummy_var.owner.inputs[0]) - for dummy_var in model_dummy_vars - } - toposort_replace(fgraph, tuple(model_dummy_vars_to_vars.items())) - - # Populate new PyMC model mappings - for model_var in model_dummy_vars: - if isinstance(model_var.owner.op, ModelFreeRV): - var, value, *dims = model_var.owner.inputs - transform = model_var.owner.op.transform - model.free_RVs.append(var) - # PyMC does not allow setting transform when we pass a value_var. Why? - model.create_value_var(var, transform=None, value_var=value) - model.rvs_to_transforms[var] = transform - model.set_initval(var, initval=None) - elif isinstance(model_var.owner.op, ModelObservedRV): - var, value, *dims = model_var.owner.inputs - model.observed_RVs.append(var) - model.create_value_var(var, transform=None, value_var=value) - elif isinstance(model_var.owner.op, ModelPotential): - var, *dims = model_var.owner.inputs - model.potentials.append(var) - elif isinstance(model_var.owner.op, ModelDeterministic): - var, *dims = model_var.owner.inputs - # If a Deterministic is a direct view on an RV, copy it - if var in model.basic_RVs: - var = var.copy() - model.deterministics.append(var) - elif isinstance(model_var.owner.op, ModelNamed): - var, *dims = model_var.owner.inputs - else: - raise TypeError(f"Unexpected ModelVar type {type(model_var)}") - - var.name = model_var.name - dims = [dim.data for dim in dims] if dims else None - model.add_named_variable(var, dims=dims) - - return model - - -def clone_model(model: Model) -> Model: - """Clone a PyMC model. - - Recreates a PyMC model with clones of the original variables. - Shared variables will point to the same container but be otherwise different objects. - Constants are not cloned. - - - Examples - -------- - - .. code-block:: python - - import pymc as pm - from pymc_experimental.utils import clone_model - - with pm.Model() as m: - p = pm.Beta("p", 1, 1) - x = pm.Bernoulli("x", p=p, shape=(3,)) - - with clone_model(m) as clone_m: - # Access cloned variables by name - clone_x = clone_m["x"] - - # z will be part of clone_m but not m - z = pm.Deterministic("z", clone_x + 1) - - """ - return model_from_fgraph(fgraph_from_model(model)[0]) - - -def extract_dims(var) -> Tuple: - dims = () - node = var.owner - if node and isinstance(node.op, ModelVar): - if isinstance(node.op, ModelValuedVar): - dims = node.inputs[2:] - else: - dims = node.inputs[1:] - return dims +warnings.warn("The functionality in this module has been moved to PyMC")
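
Note: everything removed above is available from PyMC itself, as the new stub modules and
tests assert. A minimal migration sketch for the conditioning transforms (assuming a PyMC
release that exposes `pm.do` and `pm.observe` at the top level, which is exactly what the
new `test_imports_from_pymc` checks; the toy model is illustrative):

.. code-block:: python

    import pymc as pm

    with pm.Model() as m:
        x = pm.Normal("x")
        y = pm.Normal("y", x)
        z = pm.Normal("z", y)

    # Previously: from pymc_experimental.model_transform.conditioning import observe, do
    m_obs = pm.observe(m, {y: 0.5})  # turn the free RV y into an observed RV
    m_do = pm.do(m, {y: 100.0})      # replace y by a constant intervention

`change_value_transforms` and `remove_value_transforms` are importable from
`pymc.model.transform.conditioning`, the same module the new stub re-exports.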
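The deleted `model_transform/basic.py` helpers get no stub here; assuming your PyMC
version graduated them together with the conditioning transforms (worth verifying
against your installed release), `prune_vars_detached_from_observed` can be imported as:

.. code-block:: python

    from pymc.model.transform.basic import prune_vars_detached_from_observed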
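Likewise, the fgraph round-trip utilities now live in `pymc.model.fgraph`. A short sketch
using only the names the new `utils/model_fgraph.py` stub re-imports (the model itself is
illustrative, adapted from the removed `clone_model` docstring):

.. code-block:: python

    import pymc as pm
    from pymc.model.fgraph import clone_model, fgraph_from_model, model_from_fgraph

    with pm.Model() as m:
        p = pm.Beta("p", 1, 1)
        x = pm.Bernoulli("x", p=p, shape=(3,))

    fg, memo = fgraph_from_model(m)  # Model -> FunctionGraph (variables wrapped in dummy ModelVar Ops)
    m_round = model_from_fgraph(fg)  # FunctionGraph -> a fresh, equivalent Model
    m_clone = clone_model(m)         # shorthand for the round-trip above

    with m_clone:
        # New variables land in the clone, not in the original model
        pm.Deterministic("z", m_clone["x"] + 1)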