Skip to content

Commit

Permalink
traverse through graph
Browse files Browse the repository at this point in the history
  • Loading branch information
xjing76 committed Mar 26, 2021
1 parent f78526e commit d76747e
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 48 deletions.
5 changes: 1 addition & 4 deletions pymc3_hmm/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def __init__(self, comp_dists, states, *args, **kwargs):
"""
self.states = tt.as_tensor_variable(pm.intX(states))
self._logp_like = None

if len(comp_dists) > 31:
warnings.warn(
Expand Down Expand Up @@ -187,9 +186,7 @@ def logp(self, obs):

obs_tt = tt.as_tensor_variable(obs)

shape_var = tuple(obs_tt.shape.tag.test_value)

logp_val = tt.alloc(-np.inf, *shape_var)
logp_val = tt.alloc(-np.inf, *obs.shape)

for i, dist in enumerate(self.comp_dists):
i_mask = tt.eq(self.states, i)
Expand Down
104 changes: 60 additions & 44 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
from theano.graph.op import get_test_value as test_value
from theano.graph.opt import OpRemove, pre_greedy_local_optimizer
from theano.graph.optdb import Query
from theano.scan.utils import clone
from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant

from pymc3_hmm.distributions import DiscreteMarkovChain
from pymc3_hmm.distributions import DiscreteMarkovChain, PoissonZeroProcess
from pymc3_hmm.utils import compute_trans_freqs

big: float = 1e20
Expand Down Expand Up @@ -109,6 +110,44 @@ def ffbs_astep(
return samples


def traverse_graph_and_replace(rv, state= None):
visited = []

graph = rv.logp_elemwiset

if type(rv.distribution) is PoissonZeroProcess:
log_k = graph.owner.inputs[0].owner.inputs[0]
shared_log_k = shared(log_k.eval(), name=f'shared_log_k', borrow=True)
clone(graph, {log_k: shared_log_k})

st = graph.owner.inputs[0].owner.inputs[1].owner.inputs[0].owner.inputs[1].owner.inputs[1].owner.inputs[1]
st_alloc = st.owner.inputs[0].owner.inputs[1]
shared_st = shared(st_alloc.eval() + state, name=f'shared_st_{state}', borrow=True)

clone(graph, {st_alloc: shared_st})

return graph

queue = [graph]
count = 0

while queue:
node = queue.pop(0)
if node not in visited:
visited.append(node)
if node.__str__() =="Alloc.0":
shared_var = shared(node.eval(), name = f'shared_var_{count}', borrow = True)
clone(graph, {node:shared_var})
count += 1
if node.owner is not None:
inputs = node.owner.inputs
for input in inputs:
queue.append(input)

return graph



class FFBSStep(ArrayStep):
r"""Forward-filtering backward-sampling steps.
Expand Down Expand Up @@ -142,68 +181,45 @@ def __init__(self, var, values=None, model=None):
# total log-likelihood values for each state in the sequence.
var_sample = model.test_point[var.name]

self.log_likelihood_values = []
log_likelihood_values = []
for i, dependent_rv in enumerate(self.dependent_rvs):
number_of_state = len(dependent_rv.distribution.comp_dists)
shared_logp = shared(np.zeros((number_of_state, ) +var_sample.shape),
name=f"log_likelihood_values_{i}", borrow=True)

log_p_t = []
for state_i in range(number_of_state) :
## theano.graph.basic.clone_replace
## replace state squence
for state_i in range(number_of_state):
log_p_t.append(traverse_graph_and_replace(dependent_rv, state_i))

logp_t = dependent_rv.logp_elemwiset.clone()
alloc = logp_t.owner.inputs[0].owner.inputs[0].owner.inputs[0]
if alloc.__str__() =="Alloc.0" :
logp_t.owner.inputs[0].owner.inputs[0] = shared_logp[state_i]
log_p_t.append(logp_t)
log_likelihood_values.append(tt.stack(log_p_t))

self.log_likelihood_values.append(log_p_t)


temp = tt.sum(self.log_likelihood_values, axis=0)
temp = tt.sum(log_likelihood_values, axis = 0)

dependents_log_lik = model.fn(temp)

self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)

self.lik_dict = {}
M = number_of_state
N = var_sample.shape[0]

lik_n: np.ndarray = np.empty((M,), dtype=float)
alpha_n: np.ndarray = np.empty((M,), dtype=float)
beta_n: np.ndarray = np.empty((M,), dtype=float)
samples: np.ndarray = np.empty((N,), dtype=np.int8)
alphas: np.ndarray = np.empty((M, N), dtype=float)

self.lik_dict["lik_n"] = lik_n
self.lik_dict["alpha_n"] = alpha_n
self.lik_dict["beta_n"] = beta_n
self.lik_dict["samples"] = samples
self.lik_dict["alphas"] = alphas

super().__init__([var], [dependents_log_lik], allvars=True)

def astep(self, point, log_lik_fn, inputs):
gamma_0 = self.gamma_0_fn(inputs)
Gammas_t = self.Gammas_fn(inputs)

M = gamma_0.shape[-1]
N = point.shape[-1]

state_seqs = np.broadcast_to(np.arange(M, dtype=int)[..., None], (M, N))
log_lik_t = log_lik_fn(state_seqs)


if set(self.lik_dict.keys()) == set([
"lik_n",
"alpha_n",
"beta_n",
"samples",
"alphas",
]):
pass
else:
lik_n: np.ndarray = np.empty((M,), dtype=float)
alpha_n: np.ndarray = np.empty((M,), dtype=float)
beta_n: np.ndarray = np.empty((M,), dtype=float)
samples: np.ndarray = np.empty((N,), dtype=np.int8)
alphas: np.ndarray = np.empty((M, N), dtype=float)

self.lik_dict["lik_n"] = lik_n
self.lik_dict["alpha_n"] = alpha_n
self.lik_dict["beta_n"] = beta_n
self.lik_dict["samples"] = samples
self.lik_dict["alphas"] = alphas
log_lik_t = log_lik_fn(point)

return ffbs_astep(gamma_0, Gammas_t, log_lik_t, self.lik_dict)

Expand Down

0 comments on commit d76747e

Please sign in to comment.