Skip to content

Commit

Permalink
use shared variable for FFB
Browse files Browse the repository at this point in the history
  • Loading branch information
xjing76 committed Mar 22, 2021
1 parent 3d2517d commit f9abf35
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 25 deletions.
5 changes: 2 additions & 3 deletions pymc3_hmm/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pymc3 as pm
import theano
import theano.tensor as tt
from theano import shared
from pymc3.distributions.distribution import _DrawValuesContext, draw_values
from pymc3.distributions.mixture import _conversion_map, all_discrete
from theano.graph.op import get_test_value
Expand Down Expand Up @@ -190,8 +189,8 @@ def logp(self, obs):

shape_var = tuple(obs_tt.shape.tag.test_value)

if self._logp_like is None or shape_var not in self._logp_like:
self._logp_like = {shape_var : tt.alloc(-np.inf, *shape_var)}
if self._logp_like is None or shape_var not in self._logp_like:
self._logp_like = {shape_var: tt.alloc(-np.inf, *shape_var)}

logp_val = self._logp_like[shape_var]

Expand Down
46 changes: 28 additions & 18 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import theano.tensor as tt
from pymc3.step_methods.arraystep import ArrayStep, Competence
from pymc3.util import get_untransformed_name
from theano import shared
from theano.compile import optdb
from theano.graph.basic import Variable, graph_inputs
from theano.graph.fg import FunctionGraph
Expand Down Expand Up @@ -50,16 +51,13 @@ def ffbs_astep(
# Number of observations
N: int = log_lik.shape[-1]

# Number of states
M: int = gamma_0.shape[-1]
# assert M == log_lik.shape[-2]

# Initial state probabilities
gamma_0_normed: np.ndarray = gamma_0.copy()
gamma_0_normed /= np.sum(gamma_0)

# "Forward" probabilities
alphas: np.ndarray = np.empty((M, N), dtype=float)
alphas = lik_dict["alphas"]

# Previous forward probability
alpha_nm1: np.ndarray = gamma_0_normed

Expand All @@ -86,7 +84,7 @@ def ffbs_astep(
alphas[..., n] = alpha_n

# The FFBS samples
samples: np.ndarray = np.empty((N,), dtype=np.int8)
samples = lik_dict["samples"]

# The uniform samples used to sample the categorical states
unif_samples: np.ndarray = np.random.uniform(size=samples.shape)
Expand All @@ -98,7 +96,7 @@ def ffbs_astep(

samples[N - 1] = state_np1

beta_n: np.ndarray = np.empty((M,), dtype=float)
beta_n = lik_dict["beta_n"]

# Backward sampling
for n in range(N - 2, -1, -1):
Expand Down Expand Up @@ -142,13 +140,13 @@ def __init__(self, var, values=None, model=None):

# We compile a function--from a Theano graph--that computes the
# total log-likelihood values for each state in the sequence.
dependents_log_lik = model.fn(
tt.sum([v.logp_elemwiset for v in self.dependent_rvs], axis=0)
)

temp = tt.sum([v.logp_elemwiset for v in self.dependent_rvs], axis=0)
dependents_log_lik = model.fn(temp)

self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)
self.log_lik_t = None
self.log_lik_t = shared(np.zeros((1, 1)), name="log_lik_t", borrow=True)
self.lik_dict = {}

super().__init__([var], [dependents_log_lik], allvars=True)
Expand All @@ -163,21 +161,33 @@ def astep(self, point, log_lik_fn, inputs):
# TODO: Why won't broadcasting work with `log_lik_fn`? Seems like we
# could be missing out on a much more efficient/faster approach to this
# potentially large computation.
# state_seqs = np.broadcast_to(np.arange(M, dtype=int)[..., None], (M, N))
if self.log_lik_t is None:
log_lik_t = np.stack([log_lik_fn(np.broadcast_to(m, N)) for m in range(M)])
else:
log_lik_t = self.log_lik_t

if "lik_n" in self.lik_dict and "alpha_n" in self.lik_dict:
self.log_lik_t.set_value(
np.stack([log_lik_fn(np.broadcast_to(m, N)) for m in range(M)])
)

if set(self.lik_dict.keys()) == set([
"lik_n",
"alpha_n",
"beta_n",
"samples",
"alphas",
]):
pass
else:
lik_n: np.ndarray = np.empty((M,), dtype=float)
alpha_n: np.ndarray = np.empty((M,), dtype=float)
beta_n: np.ndarray = np.empty((M,), dtype=float)
samples: np.ndarray = np.empty((N,), dtype=np.int8)
alphas: np.ndarray = np.empty((M, N), dtype=float)

self.lik_dict["lik_n"] = lik_n
self.lik_dict["alpha_n"] = alpha_n
self.lik_dict["beta_n"] = beta_n
self.lik_dict["samples"] = samples
self.lik_dict["alphas"] = alphas

return ffbs_astep(gamma_0, Gammas_t, log_lik_t, self.lik_dict)
return ffbs_astep(gamma_0, Gammas_t, self.log_lik_t.get_value(), self.lik_dict)

@staticmethod
def competence(var):
Expand Down
20 changes: 16 additions & 4 deletions tests/test_step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,19 @@ def test_ffbs_astep():
test_Gammas = np.array([[[0.9, 0.1], [0.1, 0.9]]])
test_gamma_0 = np.r_[0.5, 0.5]

lik_dict = {}
lik_dict["lik_n"] = np.empty((test_gamma_0.shape[-1],), dtype=float)
lik_dict["alpha_n"] = np.empty((test_gamma_0.shape[-1],), dtype=float)

test_log_lik_0 = np.stack(
[np.broadcast_to(0.0, 10000), np.broadcast_to(-np.inf, 10000)]
)

lik_dict = {}
M = test_gamma_0.shape[-1]
N = test_log_lik_0.shape[-1]
lik_dict["lik_n"] = np.empty((M,), dtype=float)
lik_dict["alpha_n"] = np.empty((M,), dtype=float)
lik_dict["beta_n"] = np.empty((M,), dtype=float)
lik_dict["alphas"] = np.empty((M, N), dtype=float)
lik_dict["samples"] = np.empty((N,), dtype=np.int8)

res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_0, lik_dict)
assert np.all(res == 0)

Expand Down Expand Up @@ -85,7 +91,12 @@ def test_ffbs_astep():
test_log_lik[::2] = test_log_lik[::2][:, ::-1]
test_log_lik = test_log_lik.T

N = test_log_lik.shape[-1]

lik_dict["alphas"] = np.empty((M, N), dtype=float)
lik_dict["samples"] = np.empty((N,), dtype=np.int8)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik, lik_dict)

assert np.array_equal(res, np.r_[1, 0, 0, 1])


Expand All @@ -106,6 +117,7 @@ def test_FFBSStep():
pi_0_tt = compute_steady_state(P_rv)

S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0])
S_rv.tag.test_value = (y_test > 0) * 1

PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_test)

Expand Down

0 comments on commit f9abf35

Please sign in to comment.