Skip to content

Commit

Permalink
Use BlockedStep base class for FFBSStep and update output in-place
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 31, 2021
1 parent 25efe93 commit ec8dfdf
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 54 deletions.
115 changes: 69 additions & 46 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pymc3 as pm
import theano.scalar as ts
import theano.tensor as tt
from pymc3.step_methods.arraystep import ArrayStep, Competence
from pymc3.distributions.distribution import draw_values
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name
from theano.compile import optdb
from theano.graph.basic import Variable, graph_inputs
Expand All @@ -16,33 +17,40 @@
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant

from pymc3_hmm.distributions import DiscreteMarkovChain
from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess
from pymc3_hmm.utils import compute_trans_freqs

big: float = 1e20
small: float = 1.0 / big


def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
def ffbs_step(
gamma_0: np.ndarray,
Gammas: np.ndarray,
log_lik: np.ndarray,
alphas: np.ndarray,
out: np.ndarray,
):
"""Sample a forward-filtered backward-sampled (FFBS) state sequence.
Parameters
----------
gamma_0: np.ndarray
gamma_0
The initial state probabilities.
Gamma: np.ndarray
Gamma
The transition probability matrices. This array should take the shape
`(N, M, M)`, where `N` is the state sequence length and `M` is the number
of distinct states. If `N` is `1`, the single transition matrix will
broadcast across all elements of the state sequence.
log_lik: np.ndarray
``(N, M, M)``, where ``N`` is the state sequence length and ``M`` is
the number of distinct states. If ``N`` is ``1``, the single
transition matrix will broadcast across all elements of the state
sequence.
log_lik
An array of shape `(M, N)` consisting of the log-likelihood values for
each state value at each point in the sequence.
Returns
-------
samples: np.ndarray
An array of shape `(N,)` containing the FFBS sampled state sequence.
alphas
An array in which to store the forward probabilities.
out
An output array to be updated in-place with the posterior sample
states.
"""
# Number of observations
Expand All @@ -56,8 +64,6 @@ def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
gamma_0_normed: np.ndarray = gamma_0.copy()
gamma_0_normed /= np.sum(gamma_0)

# "Forward" probabilities
alphas: np.ndarray = np.empty((M, N), dtype=float)
# Previous forward probability
alpha_nm1: np.ndarray = gamma_0_normed

Expand All @@ -83,18 +89,15 @@ def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
alpha_nm1 = alpha_n
alphas[..., n] = alpha_n

# The FFBS samples
samples: np.ndarray = np.empty((N,), dtype=np.int8)

# The uniform samples used to sample the categorical states
unif_samples: np.ndarray = np.random.uniform(size=samples.shape)
unif_samples: np.ndarray = np.random.uniform(size=out.shape)

alpha_N: np.ndarray = alphas[..., N - 1]
beta_N: np.ndarray = alpha_N / alpha_N.sum()

state_np1: np.ndarray = np.searchsorted(beta_N.cumsum(), unif_samples[N - 1])

samples[N - 1] = state_np1
out[N - 1] = state_np1

beta_n: np.ndarray = np.empty((M,), dtype=float)

Expand All @@ -104,12 +107,12 @@ def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
beta_n /= np.sum(beta_n)

state_np1 = np.searchsorted(beta_n.cumsum(), unif_samples[n])
samples[n] = state_np1
out[n] = state_np1

return samples
return out


class FFBSStep(ArrayStep):
class FFBSStep(BlockedStep):
r"""Forward-filtering backward-sampling steps.
For a hidden Markov model with state sequence :math:`S_t`, observations
Expand All @@ -126,44 +129,64 @@ class FFBSStep(ArrayStep):

name = "ffbs"

def __init__(self, var, values=None, model=None):
def __init__(self, vars, values=None, model=None):

if len(vars) > 1:
raise ValueError("This sampler only takes one variable.")

(var,) = pm.inputvars(vars)

if not isinstance(var.distribution, DiscreteMarkovChain):
raise TypeError("This sampler only samples `DiscreteMarkovChain`s.")

model = pm.modelcontext(model)

(var,) = pm.inputvars(var)
self.vars = [var]

self.dependent_rvs = [
v
for v in model.basic_RVs
if v is not var and var in graph_inputs([v.logpt])
]

# We compile a function--from a Theano graph--that computes the
# total log-likelihood values for each state in the sequence.
dependents_log_lik = model.fn(
tt.sum([v.logp_elemwiset for v in self.dependent_rvs], axis=0)
)
dep_comps_logp_stacked = []
for i, dependent_rv in enumerate(self.dependent_rvs):
if isinstance(dependent_rv.distribution, SwitchingProcess):
comp_logps = []

self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)
# Get the log-likelihoood sequences for each state in this
# `SwitchingProcess` observations distribution
for comp_dist in dependent_rv.distribution.comp_dists:
comp_logps.append(comp_dist.logp(dependent_rv))

comp_logp_stacked = tt.stack(comp_logps)
else:
raise TypeError(
"This sampler only supports `SwitchingProcess` observations"
)

super().__init__([var], [dependents_log_lik], allvars=True)
dep_comps_logp_stacked.append(comp_logp_stacked)

def astep(self, point, log_lik_fn, inputs):
gamma_0 = self.gamma_0_fn(inputs)
Gammas_t = self.Gammas_fn(inputs)
comp_logp_stacked = tt.sum(dep_comps_logp_stacked, axis=0)

M = gamma_0.shape[-1]
N = point.shape[-1]
(M,) = draw_values([var.distribution.gamma_0.shape[-1]], point=model.test_point)
N = model.test_point[var.name].shape[-1]
self.alphas = np.empty((M, N), dtype=float)

# TODO: Why won't broadcasting work with `log_lik_fn`? Seems like we
# could be missing out on a much more efficient/faster approach to this
# potentially large computation.
# state_seqs = np.broadcast_to(np.arange(M, dtype=int)[..., None], (M, N))
# log_lik_t = log_lik_fn(state_seqs)
log_lik_t = np.stack([log_lik_fn(np.broadcast_to(m, N)) for m in range(M)])
self.log_lik_states = model.fn(comp_logp_stacked)
self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)

return ffbs_astep(gamma_0, Gammas_t, log_lik_t)
def step(self, point):
gamma_0 = self.gamma_0_fn(point)
# TODO: Can we update these in-place (e.g. using a shared variable)?
Gammas_t = self.Gammas_fn(point)
# TODO: Can we update these in-place (e.g. using a shared variable)?
log_lik_state_vals = self.log_lik_states(point)
ffbs_step(
gamma_0, Gammas_t, log_lik_state_vals, self.alphas, point[self.vars[0].name]
)
return point

@staticmethod
def competence(var):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_time_varying_model():

sim_point = pm.sample_prior_predictive(samples=1, model=sim_model)

y_t = sim_point["Y_t"].squeeze()
y_t = sim_point["Y_t"].squeeze().astype(int)

split = int(len(y_t) * 0.7)

Expand Down Expand Up @@ -155,7 +155,7 @@ def test_time_varying_model():
)

# Update the shared variable values
Y.set_value(np.ones(test_X.shape[0]))
Y.set_value(np.ones(test_X.shape[0], dtype=Y.dtype))
X.set_value(test_X)

model.V_t.distribution.shape = (test_X.shape[0],)
Expand Down
43 changes: 37 additions & 6 deletions tests/test_step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from theano.graph.op import get_test_value

from pymc3_hmm.distributions import DiscreteMarkovChain, PoissonZeroProcess
from pymc3_hmm.step_methods import FFBSStep, TransMatConjugateStep, ffbs_astep
from pymc3_hmm.step_methods import FFBSStep, TransMatConjugateStep, ffbs_step
from pymc3_hmm.utils import compute_steady_state, compute_trans_freqs
from tests.utils import simulate_poiszero_hmm

Expand All @@ -24,7 +24,7 @@ def raise_under_overflow():
pytestmark = pytest.mark.usefixtures("raise_under_overflow")


def test_ffbs_astep():
def test_ffbs_step():

np.random.seed(2032)

Expand All @@ -36,13 +36,17 @@ def test_ffbs_astep():
test_log_lik_0 = np.stack(
[np.broadcast_to(0.0, 10000), np.broadcast_to(-np.inf, 10000)]
)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_0)
alphas = np.empty(test_log_lik_0.shape)
res = np.empty(test_log_lik_0.shape[-1])
ffbs_step(test_gamma_0, test_Gammas, test_log_lik_0, alphas, res)
assert np.all(res == 0)

test_log_lik_1 = np.stack(
[np.broadcast_to(-np.inf, 10000), np.broadcast_to(0.0, 10000)]
)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_1)
alphas = np.empty(test_log_lik_1.shape)
res = np.empty(test_log_lik_1.shape[-1])
ffbs_step(test_gamma_0, test_Gammas, test_log_lik_1, alphas, res)
assert np.all(res == 1)

# A well-separated mixture with non-degenerate likelihoods
Expand All @@ -59,7 +63,9 @@ def test_ffbs_astep():
# TODO FIXME: This is a statistically unsound/unstable check.
assert np.mean(np.abs(test_log_lik_p.argmax(0) - test_seq)) < 1e-2

res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_p)
alphas = np.empty(test_log_lik_p.shape)
res = np.empty(test_log_lik_p.shape[-1])
ffbs_step(test_gamma_0, test_Gammas, test_log_lik_p, alphas, res)
# TODO FIXME: This is a statistically unsound/unstable check.
assert np.mean(np.abs(res - test_seq)) < 1e-2

Expand All @@ -81,12 +87,37 @@ def test_ffbs_astep():
test_log_lik[::2] = test_log_lik[::2][:, ::-1]
test_log_lik = test_log_lik.T

res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik)
alphas = np.empty(test_log_lik.shape)
res = np.empty(test_log_lik.shape[-1])
ffbs_step(test_gamma_0, test_Gammas, test_log_lik, alphas, res)
assert np.array_equal(res, np.r_[1, 0, 0, 1])


def test_FFBSStep():

with pm.Model(), pytest.raises(ValueError):
P_rv = np.eye(2)[None, ...]
S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0], shape=10)
S_2_rv = DiscreteMarkovChain("S_2_t", P_rv, np.r_[0.0, 1.0], shape=10)
PoissonZeroProcess(
"Y_t", 9.0, S_rv + S_2_rv, observed=np.random.poisson(9.0, size=10)
)
# Only one variable can be sampled by this step method
ffbs = FFBSStep([S_rv, S_2_rv])

with pm.Model(), pytest.raises(TypeError):
S_rv = pm.Categorical("S_t", np.r_[1.0, 0.0], shape=10)
PoissonZeroProcess("Y_t", 9.0, S_rv, observed=np.random.poisson(9.0, size=10))
# Only `DiscreteMarkovChains` can be sampled with this step method
ffbs = FFBSStep([S_rv])

with pm.Model(), pytest.raises(TypeError):
P_rv = np.eye(2)[None, ...]
S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0], shape=10)
pm.Poisson("Y_t", S_rv, observed=np.random.poisson(9.0, size=10))
# Only `SwitchingProcess`es can used as dependent variables
ffbs = FFBSStep([S_rv])

np.random.seed(2032)

poiszero_sim, _ = simulate_poiszero_hmm(30, 150)
Expand Down

0 comments on commit ec8dfdf

Please sign in to comment.