From 9cf1c08d6b945ec48cff810b801916396cedb97b Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 25 May 2021 22:37:38 -0500 Subject: [PATCH] Convert step methods to v4 --- pymc3_hmm/step_methods.py | 72 +++++++++++++++-------- tests/test_step_methods.py | 114 +++++++++++++++++++++++-------------- 2 files changed, 117 insertions(+), 69 deletions(-) diff --git a/pymc3_hmm/step_methods.py b/pymc3_hmm/step_methods.py index f1d2517..97e3d0b 100644 --- a/pymc3_hmm/step_methods.py +++ b/pymc3_hmm/step_methods.py @@ -1,11 +1,9 @@ -from itertools import chain - import aesara.scalar as aes import aesara.tensor as at import numpy as np import pymc3 as pm from aesara.compile import optdb -from aesara.graph.basic import Variable, graph_inputs +from aesara.graph.basic import Variable, graph_inputs, vars_between from aesara.graph.fg import FunctionGraph from aesara.graph.op import get_test_value as test_value from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer @@ -13,16 +11,33 @@ from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.subtensor import AdvancedIncSubtensor1 from aesara.tensor.var import TensorConstant +from pymc3.aesaraf import change_rv_size +from pymc3.distributions.logp import logpt from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence from pymc3.util import get_untransformed_name -from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess +from pymc3_hmm.distributions import DiscreteMarkovChainFactory, SwitchingProcessFactory from pymc3_hmm.utils import compute_trans_freqs big: float = 1e20 small: float = 1.0 / big +def conform_rv_shape(rv_var, shape): + ndim_supp = rv_var.owner.op.ndim_supp + if ndim_supp > 0: + new_size = shape[:-ndim_supp] + else: + new_size = shape + + rv_var = change_rv_size(rv_var, new_size) + + if hasattr(rv_var.tag, "value_var"): + rv_var.tag.value_var = rv_var.type() + + return rv_var + + def ffbs_step( gamma_0: np.ndarray, Gammas: np.ndarray, @@ -133,9 
+148,9 @@ def __init__(self, vars, values=None, model=None): if len(vars) > 1: raise ValueError("This sampler only takes one variable.") - (var,) = pm.inputvars(vars) + (var,) = vars - if not isinstance(var.distribution, DiscreteMarkovChain): + if not var.owner or not isinstance(var.owner.op, DiscreteMarkovChainFactory): raise TypeError("This sampler only samples `DiscreteMarkovChain`s.") model = pm.modelcontext(model) @@ -145,18 +160,27 @@ def __init__(self, vars, values=None, model=None): self.dependent_rvs = [ v for v in model.basic_RVs - if v is not var and var in graph_inputs([v.logpt]) + if v is not var and var in vars_between(list(graph_inputs([v])), [v]) ] + if not self.dependent_rvs: + raise ValueError(f"Could not find variables that depend on {var}") + dep_comps_logp_stacked = [] for i, dependent_rv in enumerate(self.dependent_rvs): - if isinstance(dependent_rv.distribution, SwitchingProcess): + if dependent_rv.owner and isinstance( + dependent_rv.owner.op, SwitchingProcessFactory + ): comp_logps = [] # Get the log-likelihoood sequences for each state in this # `SwitchingProcess` observations distribution - for comp_dist in dependent_rv.distribution.comp_dists: - comp_logps.append(comp_dist.logp(dependent_rv)) + for comp_dist in dependent_rv.owner.inputs[ + 4 : -len(dependent_rv.owner.op.shared_inputs) + ]: + new_comp_dist = conform_rv_shape(comp_dist, dependent_rv.shape) + state_logp = logpt(new_comp_dist, dependent_rv.tag.observations) + comp_logps.append(state_logp) comp_logp_stacked = at.stack(comp_logps) else: @@ -167,15 +191,18 @@ def __init__(self, vars, values=None, model=None): dep_comps_logp_stacked.append(comp_logp_stacked) comp_logp_stacked = at.sum(dep_comps_logp_stacked, axis=0) + self.log_lik_states = model.fn(comp_logp_stacked) + + Gammas_var = var.owner.inputs[1] + gamma_0_var = var.owner.inputs[2] - # XXX: This isn't correct. 
- M = var.owner.inputs[2].eval(model.test_point) - N = model.test_point[var.name].shape[-1] + Gammas_initial_shape = model.fn(Gammas_var.shape)(model.initial_point) + M = Gammas_initial_shape[-1] + N = Gammas_initial_shape[-3] self.alphas = np.empty((M, N), dtype=float) - self.log_lik_states = model.fn(comp_logp_stacked) - self.gamma_0_fn = model.fn(var.distribution.gamma_0) - self.Gammas_fn = model.fn(var.distribution.Gammas) + self.gamma_0_fn = model.fn(gamma_0_var) + self.Gammas_fn = model.fn(Gammas_var) def step(self, point): gamma_0 = self.gamma_0_fn(point) @@ -190,9 +217,8 @@ def step(self, point): @staticmethod def competence(var): - distribution = getattr(var.distribution, "parent_dist", var.distribution) - if isinstance(distribution, DiscreteMarkovChain): + if var.owner and isinstance(var.owner.op, DiscreteMarkovChainFactory): return Competence.IDEAL # elif isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types): # return Competence.COMPATIBLE @@ -242,7 +268,7 @@ def __init__(self, model_vars, values=None, model=None, rng=None): if isinstance(model_vars, Variable): model_vars = [model_vars] - model_vars = list(chain.from_iterable([pm.inputvars(v) for v in model_vars])) + model_vars = list(model_vars) # TODO: Are the rows in this matrix our `dir_priors`? dir_priors = [] @@ -256,7 +282,7 @@ def __init__(self, model_vars, values=None, model=None, rng=None): state_seqs = [ v for v in model.vars + model.observed_RVs - if isinstance(v.distribution, DiscreteMarkovChain) + if (v.owner and isinstance(v.owner.op, DiscreteMarkovChainFactory)) and all(d in graph_inputs([v.owner.inputs[1]]) for d in dir_priors) ] @@ -429,11 +455,7 @@ def astep(self, point, inputs): @staticmethod def competence(var): - # TODO: Check that the dependent term is a conjugate type. 
- - distribution = getattr(var.distribution, "parent_dist", var.distribution) - - if isinstance(distribution, pm.Dirichlet): + if var.owner and isinstance(var.owner.op, pm.Dirichlet): return Competence.COMPATIBLE return Competence.INCOMPATIBLE diff --git a/tests/test_step_methods.py b/tests/test_step_methods.py index cf84af4..07b2555 100644 --- a/tests/test_step_methods.py +++ b/tests/test_step_methods.py @@ -1,5 +1,6 @@ import warnings +import aesara import aesara.tensor as at import numpy as np import pymc3 as pm @@ -24,6 +25,17 @@ def raise_under_overflow(): pytestmark = pytest.mark.usefixtures("raise_under_overflow") +def transform_var(model, rv_var): + value_var = model.rvs_to_values[rv_var] + transform = getattr(value_var.tag, "transform", None) + if transform is not None: + untrans_value_var = transform.forward(rv_var, value_var) + untrans_value_var.name = rv_var.name + return untrans_value_var + else: + return value_var + + def test_ffbs_step(): np.random.seed(2032) @@ -96,9 +108,9 @@ def test_ffbs_step(): def test_FFBSStep(): with pm.Model(), pytest.raises(ValueError): - P_rv = np.eye(2)[None, ...] 
- S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0], shape=10) - S_2_rv = DiscreteMarkovChain("S_2_t", P_rv, np.r_[0.0, 1.0], shape=10) + P_rv = np.broadcast_to(np.eye(2), (10, 2, 2)) + S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0]) + S_2_rv = DiscreteMarkovChain("S_2_t", P_rv, np.r_[0.0, 1.0]) PoissonZeroProcess( "Y_t", 9.0, S_rv + S_2_rv, observed=np.random.poisson(9.0, size=10) ) @@ -106,14 +118,14 @@ def test_FFBSStep(): ffbs = FFBSStep([S_rv, S_2_rv]) with pm.Model(), pytest.raises(TypeError): - S_rv = pm.Categorical("S_t", np.r_[1.0, 0.0], shape=10) + S_rv = pm.Categorical("S_t", np.r_[1.0, 0.0], size=10) PoissonZeroProcess("Y_t", 9.0, S_rv, observed=np.random.poisson(9.0, size=10)) # Only `DiscreteMarkovChains` can be sampled with this step method ffbs = FFBSStep([S_rv]) with pm.Model(), pytest.raises(TypeError): - P_rv = np.eye(2)[None, ...] - S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0], shape=10) + P_rv = np.broadcast_to(np.eye(2), (10, 2, 2)) + S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0]) pm.Poisson("Y_t", S_rv, observed=np.random.poisson(9.0, size=10)) # Only `SwitchingProcess`es can used as dependent variables ffbs = FFBSStep([S_rv]) @@ -124,26 +136,36 @@ def test_FFBSStep(): y_test = poiszero_sim["Y_t"] with pm.Model() as test_model: - p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) P_tt = at.stack([p_0_rv, p_1_rv]) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt)) - pi_0_tt = compute_steady_state(P_rv) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape)) + ) + + pi_0_tt = compute_steady_state(P_tt) - S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0]) + S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt) PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_test) with test_model: ffbs = FFBSStep([S_rv]) - 
test_point = test_model.test_point.copy() - test_point["p_0_stickbreaking__"] = poiszero_sim["p_0_stickbreaking__"] - test_point["p_1_stickbreaking__"] = poiszero_sim["p_1_stickbreaking__"] + p_0_stickbreaking__fn = aesara.function( + [test_model.rvs_to_values[p_0_rv]], transform_var(test_model, p_0_rv) + ) + p_1_stickbreaking__fn = aesara.function( + [test_model.rvs_to_values[p_1_rv]], transform_var(test_model, p_1_rv) + ) + + initial_point = test_model.initial_point.copy() + initial_point["p_0_stickbreaking__"] = p_0_stickbreaking__fn(poiszero_sim["p_0"]) + initial_point["p_1_stickbreaking__"] = p_1_stickbreaking__fn(poiszero_sim["p_1"]) - res = ffbs.step(test_point) + res = ffbs.step(initial_point) assert np.array_equal(res["S_t"], poiszero_sim["S_t"]) @@ -162,11 +184,13 @@ def test_FFBSStep_extreme(): p_1_rv = poiszero_sim["p_1"] P_tt = at.stack([p_0_rv, p_1_rv]) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt)) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape)) + ) pi_0_tt = poiszero_sim["pi_0"] - S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0]) + S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt) S_rv.tag.test_value = (y_test > 0).astype(int) # This prior is very far from the true value... 
@@ -178,12 +202,8 @@ def test_FFBSStep_extreme(): with test_model: ffbs = FFBSStep([S_rv]) - test_point = test_model.test_point.copy() - test_point["p_0_stickbreaking__"] = poiszero_sim["p_0_stickbreaking__"] - test_point["p_1_stickbreaking__"] = poiszero_sim["p_1_stickbreaking__"] - with np.errstate(over="ignore", under="ignore"): - res = ffbs.step(test_point) + res = ffbs.step(test_model.initial_point) assert np.array_equal(res["S_t"], poiszero_sim["S_t"]) @@ -213,7 +233,7 @@ def test_FFBSStep_extreme(): def test_TransMatConjugateStep(): with pm.Model() as test_model, pytest.raises(ValueError): - p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) + p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) transmat = TransMatConjugateStep(p_0_rv) np.random.seed(2032) @@ -222,25 +242,27 @@ def test_TransMatConjugateStep(): y_test = poiszero_sim["Y_t"] with pm.Model() as test_model: - p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) P_tt = at.stack([p_0_rv, p_1_rv]) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt)) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape)) + ) - pi_0_tt = compute_steady_state(P_rv) + pi_0_tt = compute_steady_state(P_tt) - S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0]) + S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt) PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_test) with test_model: transmat = TransMatConjugateStep(P_rv) - test_point = test_model.test_point.copy() - test_point["S_t"] = (y_test > 0).astype(int) + initial_point = test_model.initial_point.copy() + initial_point["S_t"] = (y_test > 0).astype(int) - res = transmat.step(test_point) + res = transmat.step(initial_point) p_0_smpl = get_test_value( p_0_rv.distribution.transform.backward(res[p_0_rv.transformed.name]) @@ -265,8 +287,8 @@ def 
test_TransMatConjugateStep_subtensors(): # Confirm that Dirichlet/non-Dirichlet mixed rows can be # parsed with pm.Model(): - d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) p_0_rv = at.as_tensor([0, 0, 1]) p_1_rv = at.zeros(3) @@ -275,8 +297,10 @@ def test_TransMatConjugateStep_subtensors(): p_2_rv = at.set_subtensor(p_1_rv[[1, 2]], d_1_rv) P_tt = at.stack([p_0_rv, p_1_rv, p_2_rv]) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt)) - DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], shape=(10,)) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt, (10,) + tuple(P_tt.shape)) + ) + DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0]) transmat = TransMatConjugateStep(P_rv) @@ -289,8 +313,8 @@ def test_TransMatConjugateStep_subtensors(): # Same thing, just with some manipulations of the transition matrix with pm.Model(): - d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) p_0_rv = at.as_tensor([0, 0, 1]) p_1_rv = at.zeros(3) @@ -301,8 +325,10 @@ def test_TransMatConjugateStep_subtensors(): P_tt = at.horizontal_stack( p_0_rv[..., None], p_1_rv[..., None], p_2_rv[..., None] ) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt.T)) - DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], shape=(10,)) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt.T, (10,) + tuple(P_tt.T.shape)) + ) + DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0]) transmat = TransMatConjugateStep(P_rv) @@ -315,8 +341,8 @@ def test_TransMatConjugateStep_subtensors(): # Use an observed `DiscreteMarkovChain` and check the conjugate results with pm.Model(): - d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + 
d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) p_0_rv = at.as_tensor([0, 0, 1]) p_1_rv = at.zeros(3) @@ -327,9 +353,9 @@ def test_TransMatConjugateStep_subtensors(): P_tt = at.horizontal_stack( p_0_rv[..., None], p_1_rv[..., None], p_2_rv[..., None] ) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt.T)) - DiscreteMarkovChain( - "S_t", P_rv, np.r_[1, 0, 0], shape=(4,), observed=np.r_[0, 1, 0, 2] + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt.T, (4,) + tuple(P_tt.T.shape)) ) + DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], observed=np.r_[0, 1, 0, 2]) transmat = TransMatConjugateStep(P_rv)