Convert step methods to v4
brandonwillard committed May 27, 2021
1 parent 2b70311 commit 9cf1c08
Showing 2 changed files with 117 additions and 69 deletions.
72 changes: 47 additions & 25 deletions pymc3_hmm/step_methods.py
@@ -1,28 +1,43 @@
from itertools import chain

import aesara.scalar as aes
import aesara.tensor as at
import numpy as np
import pymc3 as pm
from aesara.compile import optdb
from aesara.graph.basic import Variable, graph_inputs
from aesara.graph.basic import Variable, graph_inputs, vars_between
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value as test_value
from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer
from aesara.graph.optdb import Query
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.subtensor import AdvancedIncSubtensor1
from aesara.tensor.var import TensorConstant
from pymc3.aesaraf import change_rv_size
from pymc3.distributions.logp import logpt
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name

from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess
from pymc3_hmm.distributions import DiscreteMarkovChainFactory, SwitchingProcessFactory
from pymc3_hmm.utils import compute_trans_freqs

big: float = 1e20
small: float = 1.0 / big


def conform_rv_shape(rv_var, shape):
ndim_supp = rv_var.owner.op.ndim_supp
if ndim_supp > 0:
new_size = shape[:-ndim_supp]
else:
new_size = shape

rv_var = change_rv_size(rv_var, new_size)

if hasattr(rv_var.tag, "value_var"):
rv_var.tag.value_var = rv_var.type()

return rv_var


def ffbs_step(
gamma_0: np.ndarray,
Gammas: np.ndarray,
@@ -133,9 +148,9 @@ def __init__(self, vars, values=None, model=None):
if len(vars) > 1:
raise ValueError("This sampler only takes one variable.")

(var,) = pm.inputvars(vars)
(var,) = vars

if not isinstance(var.distribution, DiscreteMarkovChain):
if not var.owner or not isinstance(var.owner.op, DiscreteMarkovChainFactory):
raise TypeError("This sampler only samples `DiscreteMarkovChain`s.")

model = pm.modelcontext(model)
@@ -145,18 +160,27 @@ def __init__(self, vars, values=None, model=None):
self.dependent_rvs = [
v
for v in model.basic_RVs
if v is not var and var in graph_inputs([v.logpt])
if v is not var and var in vars_between(list(graph_inputs([v])), [v])
]

if not self.dependent_rvs:
raise ValueError(f"Could not find variables that depend on {var}")

dep_comps_logp_stacked = []
for i, dependent_rv in enumerate(self.dependent_rvs):
if isinstance(dependent_rv.distribution, SwitchingProcess):
if dependent_rv.owner and isinstance(
dependent_rv.owner.op, SwitchingProcessFactory
):
comp_logps = []

# Get the log-likelihood sequences for each state in this
# `SwitchingProcess` observation distribution
for comp_dist in dependent_rv.distribution.comp_dists:
comp_logps.append(comp_dist.logp(dependent_rv))
for comp_dist in dependent_rv.owner.inputs[
4 : -len(dependent_rv.owner.op.shared_inputs)
]:
new_comp_dist = conform_rv_shape(comp_dist, dependent_rv.shape)
state_logp = logpt(new_comp_dist, dependent_rv.tag.observations)
comp_logps.append(state_logp)

comp_logp_stacked = at.stack(comp_logps)
else:
@@ -167,15 +191,18 @@ def __init__(self, vars, values=None, model=None):
dep_comps_logp_stacked.append(comp_logp_stacked)

comp_logp_stacked = at.sum(dep_comps_logp_stacked, axis=0)
self.log_lik_states = model.fn(comp_logp_stacked)

Gammas_var = var.owner.inputs[1]
gamma_0_var = var.owner.inputs[2]

# XXX: This isn't correct.
M = var.owner.inputs[2].eval(model.test_point)
N = model.test_point[var.name].shape[-1]
Gammas_initial_shape = model.fn(Gammas_var.shape)(model.initial_point)
M = Gammas_initial_shape[-1]
N = Gammas_initial_shape[-3]
self.alphas = np.empty((M, N), dtype=float)

self.log_lik_states = model.fn(comp_logp_stacked)
self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)
self.gamma_0_fn = model.fn(gamma_0_var)
self.Gammas_fn = model.fn(Gammas_var)

def step(self, point):
gamma_0 = self.gamma_0_fn(point)
@@ -190,9 +217,8 @@ def step(self, point):

@staticmethod
def competence(var):
distribution = getattr(var.distribution, "parent_dist", var.distribution)

if isinstance(distribution, DiscreteMarkovChain):
if var.owner and isinstance(var.owner.op, DiscreteMarkovChainFactory):
return Competence.IDEAL
# elif isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
# return Competence.COMPATIBLE
@@ -242,7 +268,7 @@ def __init__(self, model_vars, values=None, model=None, rng=None):
if isinstance(model_vars, Variable):
model_vars = [model_vars]

model_vars = list(chain.from_iterable([pm.inputvars(v) for v in model_vars]))
model_vars = list(model_vars)

# TODO: Are the rows in this matrix our `dir_priors`?
dir_priors = []
@@ -256,7 +282,7 @@ def __init__(self, model_vars, values=None, model=None, rng=None):
state_seqs = [
v
for v in model.vars + model.observed_RVs
if isinstance(v.distribution, DiscreteMarkovChain)
if (v.owner.op and isinstance(v.owner.op, DiscreteMarkovChainFactory))
and all(d in graph_inputs([v.distribution.Gammas]) for d in dir_priors)
]

@@ -429,11 +455,7 @@ def astep(self, point, inputs):
@staticmethod
def competence(var):

# TODO: Check that the dependent term is a conjugate type.

distribution = getattr(var.distribution, "parent_dist", var.distribution)

if isinstance(distribution, pm.Dirichlet):
if var.owner and isinstance(var.owner.op, pm.Dirichlet):
return Competence.COMPATIBLE

return Competence.INCOMPATIBLE
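
For context on the pattern above: in PyMC v4, random variables no longer expose a `.distribution` attribute, so the step methods identify them through the Op that produced them (`var.owner.op`). Below is a minimal, illustrative sketch of that type-check idiom, assuming only the `DiscreteMarkovChainFactory` op imported at the top of this diff; the helper name is hypothetical and not part of the commit.

from pymc3_hmm.distributions import DiscreteMarkovChainFactory


def is_discrete_markov_chain(var):
    # Hypothetical helper: the v3 code checked `isinstance(var.distribution, DiscreteMarkovChain)`;
    # in v4 the variable is the output of an Op, so the check walks the owning Apply node instead.
    return var.owner is not None and isinstance(var.owner.op, DiscreteMarkovChainFactory)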