diff --git a/README.md b/README.md
old mode 100644
new mode 100755
index cdb40d8..2a91055
--- a/README.md
+++ b/README.md
@@ -1 +1,3 @@
-# naturecomm_cscg
\ No newline at end of file
+# Clone-Structured Cognitive Graphs
+
+Code for ["Learning cognitive maps as structured graphs for vicarious evaluation"](https://www.biorxiv.org/content/10.1101/864421v4.full)
diff --git a/chmm_actions.py b/chmm_actions.py
new file mode 100755
index 0000000..afead55
--- /dev/null
+++ b/chmm_actions.py
@@ -0,0 +1,646 @@
+#***********************************************************************
+#
+# VICARIOUS CONFIDENTIAL
+# __________________
+#
+# [2010] - [2021] Vicarious
+# All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains
+# the property of Vicarious Incorporated and its suppliers,
+# if any. The intellectual and technical concepts contained
+# herein are proprietary to Vicarious Incorporated
+# and its suppliers and may be covered by U.S. and Foreign Patents,
+# patents in process, and are protected by trade secret or copyright law.
+# Dissemination of this information or reproduction of this material
+# is strictly forbidden unless prior written permission is obtained
+# from Vicarious Incorporated.
+#
+#***********************************************************************
+
+from __future__ import print_function
+from builtins import range
+import numpy as np
+import numba as nb
+from tqdm import trange
+import sys
+
+def validate_seq(x, a, n_clones=None):
+    """Validate an input sequence of observations x and actions a"""
+    assert len(x) == len(a) > 0
+    assert len(x.shape) == len(a.shape) == 1, "Flatten your array first"
+    assert x.dtype == a.dtype == np.int64
+    assert 0 <= x.min(), "Observations must be non-negative integers"
+    if n_clones is not None:
+        assert len(n_clones.shape) == 1, "Flatten your array first"
+        assert n_clones.dtype == np.int64
+        assert all(
+            [c > 0 for c in n_clones]
+        ), "You can't provide zero clones for any emission"
+        n_emissions = n_clones.shape[0]
+        assert (
+            x.max() < n_emissions
+        ), "Number of emissions inconsistent with training sequence"
+
+
+class CHMM(object):
+    def __init__(self, n_clones, x, a, pseudocount=0.0, dtype=np.float32):
+        """Construct a CHMM object. n_clones is an array where n_clones[i] is the
+        number of clones assigned to observation i. 
x and a are the observation sequences + and action sequences, respectively.""" + self.n_clones = n_clones + validate_seq(x, a, self.n_clones) + assert pseudocount >= 0.0, "The pseudocount should be positive" + print("Average number of clones:", n_clones.mean()) + self.pseudocount = pseudocount + self.dtype = dtype + n_states = self.n_clones.sum() + n_actions = a.max() + 1 + self.C = np.random.rand(n_actions, n_states, n_states).astype(dtype) + self.Pi_x = np.ones(n_states) / n_states + self.Pi_a = np.ones(n_actions) / n_actions + self.update_T() + + def update_T(self): + """Update the transition matrix given the accumulated counts matrix.""" + self.T = self.C + self.pseudocount + norm = self.T.sum(2, keepdims=True) + norm[norm == 0] = 1 + self.T /= norm + + # def update_T(self): + # self.T = self.C + self.pseudocount + # norm = self.T.sum(2, keepdims=True) # old model (conditional on actions) + # norm[norm == 0] = 1 + # self.T /= norm + # norm = self.T.sum((0, 2), keepdims=True) # new model (generates actions too) + # norm[norm == 0] = 1 + # self.T /= norm + + def update_E(self, CE): + """Update the emission matrix.""" + E = CE + self.pseudocount + norm = E.sum(1, keepdims=True) + norm[norm == 0] = 1 + E /= norm + return E + + def bps(self, x, a): + """Compute the log likelihood (log base 2) of a sequence of observations and actions.""" + validate_seq(x, a, self.n_clones) + log2_lik = forward(self.T.transpose(0, 2, 1), self.Pi_x, self.n_clones, x, a)[0] + return -log2_lik + + def bpsE(self, E, x, a): + """Compute the log likelihood using an alternate emission matrix.""" + validate_seq(x, a, self.n_clones) + log2_lik = forwardE( + self.T.transpose(0, 2, 1), E, self.Pi_x, self.n_clones, x, a + ) + return -log2_lik + + def bpsV(self, x, a): + validate_seq(x, a, self.n_clones) + log2_lik = forward_mp( + self.T.transpose(0, 2, 1), self.Pi_x, self.n_clones, x, a + )[0] + return -log2_lik + + def decode(self, x, a): + """Compute the MAP assignment of latent variables using max-product message passing.""" + log2_lik, mess_fwd = forward_mp( + self.T.transpose(0, 2, 1), + self.Pi_x, + self.n_clones, + x, + a, + store_messages=True, + ) + states = backtrace(self.T, self.n_clones, x, a, mess_fwd) + return -log2_lik, states + + def decodeE(self, E, x, a): + """Compute the MAP assignment of latent variables using max-product message passing + with an alternative emission matrix.""" + log2_lik, mess_fwd = forwardE_mp( + self.T.transpose(0, 2, 1), + E, + self.Pi_x, + self.n_clones, + x, + a, + store_messages=True, + ) + states = backtraceE(self.T, E, self.n_clones, x, a, mess_fwd) + return -log2_lik, states + + def learn_em_T(self, x, a, n_iter=100, term_early=True): + """Run EM training, keeping E deterministic and fixed, learning T""" + sys.stdout.flush() + convergence = [] + pbar = trange(n_iter, position=0) + log2_lik_old = -np.inf + for it in pbar: + # E + log2_lik, mess_fwd = forward( + self.T.transpose(0, 2, 1), + self.Pi_x, + self.n_clones, + x, + a, + store_messages=True, + ) + mess_bwd = backward(self.T, self.n_clones, x, a) + updateC(self.C, self.T, self.n_clones, mess_fwd, mess_bwd, x, a) + # M + self.update_T() + convergence.append(-log2_lik.mean()) + pbar.set_postfix(train_bps=convergence[-1]) + if log2_lik.mean() <= log2_lik_old: + if term_early: + break + log2_lik_old = log2_lik.mean() + return convergence + + def learn_viterbi_T(self, x, a, n_iter=100): + """Run Viterbi training, keeping E deterministic and fixed, learning T""" + sys.stdout.flush() + convergence = [] + pbar = 
trange(n_iter, position=0) + log2_lik_old = -np.inf + for it in pbar: + # E + log2_lik, mess_fwd = forward_mp( + self.T.transpose(0, 2, 1), + self.Pi_x, + self.n_clones, + x, + a, + store_messages=True, + ) + states = backtrace(self.T, self.n_clones, x, a, mess_fwd) + self.C[:] = 0 + for t in range(1, len(x)): + aij, i, j = ( + a[t - 1], + states[t - 1], + states[t], + ) # at time t-1 -> t we go from observation i to observation j + self.C[aij, i, j] += 1.0 + # M + self.update_T() + + convergence.append(-log2_lik.mean()) + pbar.set_postfix(train_bps=convergence[-1]) + if log2_lik.mean() <= log2_lik_old: + break + log2_lik_old = log2_lik.mean() + return convergence + + def learn_em_E(self, x, a, n_iter=100, pseudocount_extra=1e-20): + """Run Viterbi training, keeping T fixed, learning E""" + sys.stdout.flush() + n_emissions, n_states = len(self.n_clones), self.n_clones.sum() + CE = np.ones((n_states, n_emissions), self.dtype) + E = self.update_E(CE + pseudocount_extra) + convergence = [] + pbar = trange(n_iter, position=0) + log2_lik_old = -np.inf + for it in pbar: + # E + log2_lik, mess_fwd = forwardE( + self.T.transpose(0, 2, 1), + E, + self.Pi_x, + self.n_clones, + x, + a, + store_messages=True, + ) + mess_bwd = backwardE(self.T, E, self.n_clones, x, a) + updateCE(CE, E, self.n_clones, mess_fwd, mess_bwd, x, a) + # M + E = self.update_E(CE + pseudocount_extra) + convergence.append(-log2_lik.mean()) + pbar.set_postfix(train_bps=convergence[-1]) + if log2_lik.mean() <= log2_lik_old: + break + log2_lik_old = log2_lik.mean() + return convergence, E + + def sample(self, length): + """Sample from the CHMM.""" + assert length > 0 + state_loc = np.hstack(([0], self.n_clones)).cumsum(0) + sample_x = np.zeros(length, dtype=np.int64) + sample_a = np.random.choice(len(self.Pi_a), size=length, p=self.Pi_a) + + # Sample + p_h = self.Pi_x + for t in range(length): + h = np.random.choice(len(p_h), p=p_h) + sample_x[t] = np.digitize(h, state_loc) - 1 + p_h = self.T[sample_a[t], h] + return sample_x, sample_a + + def sample_sym(self, sym, length): + """Sample from the CHMM conditioning on an inital observation.""" + # Prepare structures + assert length > 0 + state_loc = np.hstack(([0], self.n_clones)).cumsum(0) + + seq = [sym] + + alpha = np.ones(self.n_clones[sym]) + alpha /= alpha.sum() + + for _ in range(length): + obs_tm1 = seq[-1] + T_weighted = self.T.sum(0) + + long_alpha = np.dot( + alpha, T_weighted[state_loc[obs_tm1] : state_loc[obs_tm1 + 1], :] + ) + long_alpha /= long_alpha.sum() + idx = np.random.choice(np.arange(self.n_clones.sum()), p=long_alpha) + + sym = np.digitize(idx, state_loc) - 1 + seq.append(sym) + + temp_alpha = long_alpha[state_loc[sym] : state_loc[sym + 1]] + temp_alpha /= temp_alpha.sum() + alpha = temp_alpha + + return seq + + def bridge(self, state1, state2, max_steps=100): + Pi_x = np.zeros(self.n_clones.sum(), dtype=self.dtype) + Pi_x[state1] = 1 + log2_lik, mess_fwd = forward_mp_all( + self.T.transpose(0, 2, 1), Pi_x, self.Pi_a, self.n_clones, state2, max_steps + ) + s_a = backtrace_all(self.T, self.Pi_a, self.n_clones, mess_fwd, state2) + return s_a + +def updateCE(CE, E, n_clones, mess_fwd, mess_bwd, x, a): + timesteps = len(x) + gamma = mess_fwd * mess_bwd + norm = gamma.sum(1, keepdims=True) + norm[norm == 0] = 1 + gamma /= norm + CE[:] = 0 + for t in range(timesteps): + CE[:, x[t]] += gamma[t] + + +def forwardE(T_tr, E, Pi, n_clones, x, a, store_messages=False): + """ Log-probability of a sequence, and optionally, messages""" + assert (n_clones.sum(), len(n_clones)) 
== E.shape + dtype = T_tr.dtype.type + + # forward pass + t, log2_lik = 0, np.zeros(len(x), dtype) + j = x[t] + message = Pi * E[:, j] + p_obs = message.sum() + assert p_obs > 0 + message /= p_obs + log2_lik[0] = np.log2(p_obs) + if store_messages: + mess_fwd = np.empty((len(x), E.shape[0]), dtype=dtype) + mess_fwd[t] = message + for t in range(1, x.shape[0]): + aij, j = ( + a[t - 1], + x[t], + ) # at time t-1 -> t we go from observation i to observation j + message = T_tr[aij].dot(message) + message *= E[:, j] + p_obs = message.sum() + assert p_obs > 0 + message /= p_obs + log2_lik[t] = np.log2(p_obs) + if store_messages: + mess_fwd[t] = message + if store_messages: + return log2_lik, mess_fwd + else: + return log2_lik + + +def backwardE(T, E, n_clones, x, a): + """Compute backward messages.""" + assert (n_clones.sum(), len(n_clones)) == E.shape + dtype = T.dtype.type + + # backward pass + t = x.shape[0] - 1 + message = np.ones(E.shape[0], dtype) + message /= message.sum() + mess_bwd = np.empty((len(x), E.shape[0]), dtype=dtype) + mess_bwd[t] = message + for t in range(x.shape[0] - 2, -1, -1): + aij, j = ( + a[t], + x[t + 1], + ) # at time t -> t+1 we go from observation i to observation j + message = T[aij].dot(message * E[:, j]) + p_obs = message.sum() + assert p_obs > 0 + message /= p_obs + mess_bwd[t] = message + return mess_bwd + + +@nb.njit +def updateC(C, T, n_clones, mess_fwd, mess_bwd, x, a): + state_loc = np.hstack((np.array([0], dtype=n_clones.dtype), n_clones)).cumsum() + mess_loc = np.hstack((np.array([0], dtype=n_clones.dtype), n_clones[x])).cumsum() + timesteps = len(x) + C[:] = 0 + for t in range(1, timesteps): + aij, i, j = ( + a[t - 1], + x[t - 1], + x[t], + ) # at time t-1 -> t we go from observation i to observation j + (tm1_start, tm1_stop), (t_start, t_stop) = ( + mess_loc[t - 1 : t + 1], + mess_loc[t : t + 2], + ) + (i_start, i_stop), (j_start, j_stop) = ( + state_loc[i : i + 2], + state_loc[j : j + 2], + ) + q = ( + mess_fwd[tm1_start:tm1_stop].reshape(-1, 1) + * T[aij, i_start:i_stop, j_start:j_stop] + * mess_bwd[t_start:t_stop].reshape(1, -1) + ) + q /= q.sum() + C[aij, i_start:i_stop, j_start:j_stop] += q + + +@nb.njit +def forward(T_tr, Pi, n_clones, x, a, store_messages=False): + """ Log-probability of a sequence, and optionally, messages""" + state_loc = np.hstack((np.array([0], dtype=n_clones.dtype), n_clones)).cumsum() + dtype = T_tr.dtype.type + + # forward pass + t, log2_lik = 0, np.zeros(len(x), dtype) + j = x[t] + j_start, j_stop = state_loc[j : j + 2] + message = Pi[j_start:j_stop].copy().astype(dtype) + p_obs = message.sum() + assert p_obs > 0 + message /= p_obs + log2_lik[0] = np.log2(p_obs) + if store_messages: + mess_loc = np.hstack( + (np.array([0], dtype=n_clones.dtype), n_clones[x]) + ).cumsum() + mess_fwd = np.empty(mess_loc[-1], dtype=dtype) + t_start, t_stop = mess_loc[t : t + 2] + mess_fwd[t_start:t_stop] = message + else: + mess_fwd = None + + for t in range(1, x.shape[0]): + aij, i, j = ( + a[t - 1], + x[t - 1], + x[t], + ) # at time t-1 -> t we go from observation i to observation j + (i_start, i_stop), (j_start, j_stop) = ( + state_loc[i : i + 2], + state_loc[j : j + 2], + ) + message = np.ascontiguousarray(T_tr[aij, j_start:j_stop, i_start:i_stop]).dot( + message + ) + p_obs = message.sum() + assert p_obs > 0 + message /= p_obs + log2_lik[t] = np.log2(p_obs) + if store_messages: + t_start, t_stop = mess_loc[t : t + 2] + mess_fwd[t_start:t_stop] = message + return log2_lik, mess_fwd + + +@nb.njit +def backward(T, n_clones, x, a): + 
"""Compute backward messages.""" + state_loc = np.hstack((np.array([0], dtype=n_clones.dtype), n_clones)).cumsum() + dtype = T.dtype.type + + # backward pass + t = x.shape[0] - 1 + i = x[t] + message = np.ones(n_clones[i], dtype) / n_clones[i] + message /= message.sum() + mess_loc = np.hstack((np.array([0], dtype=n_clones.dtype), n_clones[x])).cumsum() + mess_bwd = np.empty(mess_loc[-1], dtype) + t_start, t_stop = mess_loc[t : t + 2] + mess_bwd[t_start:t_stop] = message + for t in range(x.shape[0] - 2, -1, -1): + aij, i, j = ( + a[t], + x[t], + x[t + 1], + ) # at time t -> t+1 we go from observation i to observation j + (i_start, i_stop), (j_start, j_stop) = ( + state_loc[i : i + 2], + state_loc[j : j + 2], + ) + message = np.ascontiguousarray(T[aij, i_start:i_stop, j_start:j_stop]).dot( + message + ) + p_obs = message.sum() + assert p_obs > 0 + message /= p_obs + t_start, t_stop = mess_loc[t : t + 2] + mess_bwd[t_start:t_stop] = message + return mess_bwd + + +@nb.njit +def forward_mp(T_tr, Pi, n_clones, x, a, store_messages=False): + """ Log-probability of a sequence, and optionally, messages""" + state_loc = np.hstack((np.array([0], dtype=n_clones.dtype), n_clones)).cumsum() + dtype = T_tr.dtype.type + + # forward pass + t, log2_lik = 0, np.zeros(len(x), dtype) + j = x[t] + j_start, j_stop = state_loc[j : j + 2] + message = Pi[j_start:j_stop].copy().astype(dtype) + p_obs = message.max() + assert p_obs > 0 + message /= p_obs + log2_lik[0] = np.log2(p_obs) + if store_messages: + mess_loc = np.hstack( + (np.array([0], dtype=n_clones.dtype), n_clones[x]) + ).cumsum() + mess_fwd = np.empty(mess_loc[-1], dtype=dtype) + t_start, t_stop = mess_loc[t : t + 2] + mess_fwd[t_start:t_stop] = message + else: + mess_fwd = None + + for t in range(1, x.shape[0]): + aij, i, j = ( + a[t - 1], + x[t - 1], + x[t], + ) # at time t-1 -> t we go from observation i to observation j + (i_start, i_stop), (j_start, j_stop) = ( + state_loc[i : i + 2], + state_loc[j : j + 2], + ) + new_message = np.zeros(j_stop - j_start, dtype=dtype) + for d in range(len(new_message)): + new_message[d] = (T_tr[aij, j_start + d, i_start:i_stop] * message).max() + message = new_message + p_obs = message.max() + assert p_obs > 0 + message /= p_obs + log2_lik[t] = np.log2(p_obs) + if store_messages: + t_start, t_stop = mess_loc[t : t + 2] + mess_fwd[t_start:t_stop] = message + return log2_lik, mess_fwd + + +@nb.njit +def rargmax(x): + # return x.argmax() # <- favors clustering towards smaller state numbers + return np.random.choice((x == x.max()).nonzero()[0]) + + +@nb.njit +def backtrace(T, n_clones, x, a, mess_fwd): + """Compute backward messages.""" + state_loc = np.hstack((np.array([0], dtype=n_clones.dtype), n_clones)).cumsum() + mess_loc = np.hstack((np.array([0], dtype=n_clones.dtype), n_clones[x])).cumsum() + code = np.zeros(x.shape[0], dtype=np.int64) + + # backward pass + t = x.shape[0] - 1 + i = x[t] + t_start, t_stop = mess_loc[t : t + 2] + belief = mess_fwd[t_start:t_stop] + code[t] = rargmax(belief) + for t in range(x.shape[0] - 2, -1, -1): + aij, i, j = ( + a[t], + x[t], + x[t + 1], + ) # at time t -> t+1 we go from observation i to observation j + (i_start, i_stop), j_start = state_loc[i : i + 2], state_loc[j] + t_start, t_stop = mess_loc[t : t + 2] + belief = ( + mess_fwd[t_start:t_stop] * T[aij, i_start:i_stop, j_start + code[t + 1]] + ) + code[t] = rargmax(belief) + states = state_loc[x] + code + return states + + +def backtraceE(T, E, n_clones, x, a, mess_fwd): + """Compute backward messages.""" + assert 
(n_clones.sum(), len(n_clones)) == E.shape + states = np.zeros(x.shape[0], dtype=np.int64) + + # backward pass + t = x.shape[0] - 1 + belief = mess_fwd[t] + states[t] = rargmax(belief) + for t in range(x.shape[0] - 2, -1, -1): + aij = a[t] # at time t -> t+1 we go from observation i to observation j + belief = mess_fwd[t] * T[aij, :, states[t + 1]] + states[t] = rargmax(belief) + return states + + +def forwardE_mp(T_tr, E, Pi, n_clones, x, a, store_messages=False): + """ Log-probability of a sequence, and optionally, messages""" + assert (n_clones.sum(), len(n_clones)) == E.shape + dtype = T_tr.dtype.type + + # forward pass + t, log2_lik = 0, np.zeros(len(x), dtype) + j = x[t] + message = Pi * E[:, j] + p_obs = message.max() + assert p_obs > 0 + message /= p_obs + log2_lik[0] = np.log2(p_obs) + if store_messages: + mess_fwd = np.empty((len(x), E.shape[0]), dtype=dtype) + mess_fwd[t] = message + for t in range(1, x.shape[0]): + aij, j = ( + a[t - 1], + x[t], + ) # at time t-1 -> t we go from observation i to observation j + message = (T_tr[aij] * message.reshape(1, -1)).max(1) + message *= E[:, j] + p_obs = message.max() + assert p_obs > 0 + message /= p_obs + log2_lik[t] = np.log2(p_obs) + if store_messages: + mess_fwd[t] = message + if store_messages: + return log2_lik, mess_fwd + else: + return log2_lik + + +def forward_mp_all(T_tr, Pi_x, Pi_a, n_clones, target_state, max_steps): + """ Log-probability of a sequence, and optionally, messages""" + # forward pass + t, log2_lik = 0, [] + message = Pi_x + p_obs = message.max() + assert p_obs > 0 + message /= p_obs + log2_lik.append(np.log2(p_obs)) + mess_fwd = [] + mess_fwd.append(message) + T_tr_maxa = (T_tr * Pi_a.reshape(-1, 1, 1)).max(0) + for t in range(1, max_steps): + message = (T_tr_maxa * message.reshape(1, -1)).max(1) + p_obs = message.max() + assert p_obs > 0 + message /= p_obs + log2_lik.append(np.log2(p_obs)) + mess_fwd.append(message) + if message[target_state] > 0: + break + else: + assert False, "Unable to find a bridging path" + return np.array(log2_lik), np.array(mess_fwd) + + +def backtrace_all(T, Pi_a, n_clones, mess_fwd, target_state): + """Compute backward messages.""" + states = np.zeros(mess_fwd.shape[0], dtype=np.int64) + actions = np.zeros(mess_fwd.shape[0], dtype=np.int64) + n_states = T.shape[1] + # backward pass + t = mess_fwd.shape[0] - 1 + actions[t], states[t] = ( + -1, + target_state, + ) # last actions is irrelevant, use an invalid value + for t in range(mess_fwd.shape[0] - 2, -1, -1): + belief = ( + mess_fwd[t].reshape(1, -1) * T[:, :, states[t + 1]] * Pi_a.reshape(-1, 1) + ) + a_s = rargmax(belief.flatten()) + actions[t], states[t] = a_s // n_states, a_s % n_states + return actions, states diff --git a/intro.ipynb b/intro.ipynb new file mode 100755 index 0000000..1591050 --- /dev/null +++ b/intro.ipynb @@ -0,0 +1,480 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from chmm_actions import CHMM\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a CHMM" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average number of clones: 5.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:03<00:00, 29.41it/s, train_bps=0.998]\n" + ] + }, + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "0.4996067139958269\n" + ] + } + ], + "source": [ + "# Train CHMM on random data\n", + "TIMESTEPS = 1000\n", + "OBS = 2\n", + "x = np.random.randint(OBS, size=(1000,)) # Observations. Replace with your data.\n", + "a = np.zeros(1000, dtype=np.int64) # If there are actions in your domain replace this. If not, keep the vector of zeros.\n", + "n_clones = np.ones(OBS, dtype=np.int64) * 5 # Number of clones specifies the capacity for each observation. \n", + "\n", + "x_test = np.random.randint(OBS, size=(1000,)) # Test observations. Replace with your data.\n", + "a_test = np.zeros(1000, dtype=np.int64)\n", + "\n", + "chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a) # Initialize the model\n", + "progression = chmm.learn_em_T(x, a, n_iter=100, term_early=False) # Training\n", + "\n", + "nll_per_prediction = chmm.bps(x_test, a_test) # Evaluate negative log-likelihood (base 2 log)\n", + "avg_nll = np.mean(nll_per_prediction)\n", + "avg_prediction_probability = 2**(-avg_nll)\n", + "print(avg_prediction_probability)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Rectangular room datagen" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUUAAAECCAYAAABpKcWJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAM+UlEQVR4nO3df6zddX3H8deLS6HYogXpSNd2K38gjpAI5KbLgjEOh4FB0LmYlE3/2JbUxB8pmYnR/bO5f5bFxJklztlANxYRxvgRHWEiRgySTeC2lkkp3QBxtMMVKIy2SIHb1/64p8u73Nvdb+F8z+crez6SG+695+TcV9vL837Pj3uOkwgAMOeE1gMAYEiIIgAURBEACqIIAAVRBICCKAJAMdgo2r7U9i7bj9r+7AD2bLG91/ZDrbccYXut7bttP2x7h+1NA9i01Pb9th8cbfp8601H2J6y/UPbt7fecoTtJ2z/yPZ22zOt90iS7RW2b7b9iO2dtn+t8Z5zRn8/R95esH11b19viI9TtD0l6d8kXSJpt6QHJF2V5OGGm94j6YCkv0tyXqsdle1VklYl2Wb7VElbJX2w8d+TJS1LcsD2Ekn3StqU5AetNh1h+w8lTUt6a5IrWu+R5qIoaTrJM623HGH7OknfT3KN7ZMkvSXJ841nSfrfNuyR9KtJftLH1xjqkeJ6SY8meTzJy5JulPSBloOS3CNpX8sNr5XkqSTbRu/vl7RT0urGm5LkwOjDJaO35j95ba+RdLmka1pvGTLbb5P0HknXSlKSl4cSxJH3SXqsryBKw43iaklPlo93q/H/7ENne52kCyTd13jKkaup2yXtlXRXkuabJH1J0mckHW6847Ui6du2t9re2HqMpLMkPS3pb0Y3NVxje1nrUcUGSTf0+QWGGkUcB9vLJd0i6eokL7Tek2Q2yfmS1khab7vpzQ22r5C0N8nWljuO4d1JLpR0maRPjG6maelESRdK+kqSCyQdlNT8Nn1JGl2Vv1LSP/T5dYYaxT2S1paP14w+h9cY3W53i6Trk9zaek81utp1t6RLG0+5SNKVo9vvbpR0se2vtZ00J8me0X/3SrpNczcdtbRb0u5ydH+z5iI5BJdJ2pbkv/r8IkON4gOSzrZ91uinwwZJ32y8aXBGd2pcK2lnki+23iNJtlfaXjF6/xTN3Vn2SMtNST6XZE2SdZr7Xvpuko+03CRJtpeN7iDT6Crq+yU1fXRDkp9KetL2OaNPvU9SszvuXuMq9XzVWZo7VB6cJK/a/qSkOyVNSdqSZEfLTbZvkPReSWfY3i3pj5Nc23KT5o6APirpR6Pb8CTpj5Lc0W6SVkm6bnQv4QmSbkoymIfADMyZkm6b+9mmEyV9Pcm32k6SJH1K0vWjA5LHJf1e4z1HfmhcIuljvX+tIT4kBwBaGerVZwBogigCQEEUAaAgigBQEEUAKAYdxYH82tNRhrhJGuYuNnXDpu4msWvQUZQ0xH+YIW6ShrmLTd2wqbv/91EEgInq5cHbZ5w+lXVrl7zhy3n62VmtfPvUGBZJO366ciyXM/uzg5o6ZTxPGpLx/NEkSbMHD2pq2ZCezGS8m07e98pYLufl2Z/ppKlTxnNZb3vj3+PSeL+nDi8dy8Vo9sABTS1fPp4Lk7TsLS+N5XIOPfeSTj7tjf8hDz61X4eef8kLndbLr/mtW7tE99+5dvEzTtC7/vzjrSfM88r4vufe9Nb9/VOtJ8yz5/JVrSfMs/8ds60nLGj9+f/eesJRvvP7x37uFK4+A0BBFAGgIIoAUBBFACiIIgAURBEACqIIAAVRBICCKAJAQRQBoCCKAFAQRQAoiCIAFJ2iaPtS27tsP2r7s32PAoBWFo2i7SlJX5Z0maRzJV1l+9y+hwFAC12OFNdLejTJ40lelnSjpA/0OwsA2ugSxdWSniwf7x59DgDedMZ2R4vtjbZnbM88/ewwn/0XABbTJYp7JNXXFlgz+txRkmxOMp1kelyvqwIAk9Ylig9IOtv2WbZPkrRB0jf7nQUAbSz6wlVJXrX9SUl3SpqStCXJjt6XAUADnV7NL8kdku7oeQsANMdvtABAQRQBoCCKAFAQRQAoiCIAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKDo9IcT
xOpRZ/fiVA31c9Ou26p/3t54wz2O/vbz1hAWt2NV6wXyfvnN4z1b3qS0faz1hnsd/66utJyzo3C9/vPWEo7z430uPeRpHigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEACqIIAAVRBICCKAJAsWgUbW+xvdf2Q5MYBAAtdTlS/FtJl/a8AwAGYdEoJrlH0r4JbAGA5rhNEQCKsUXR9kbbM7Zn9u07PK6LBYCJGlsUk2xOMp1k+vTTOQAF8POJegFA0eUhOTdI+hdJ59jebfsP+p8FAG0s+rrPSa6axBAAGAKuPgNAQRQBoCCKAFAQRQAoiCIAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABROMvYLPf1XVuY3tnxo7Jf7Rtz/2LrWE+Z5558+13rCz43Dpy1vPWGeZ88b3qbnz2m9YGErdrVecLRHvvEXOvjMk17oNI4UAaAgigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEACqIIAMWiUbS91vbdth+2vcP2pkkMA4AWTuxwnlclfTrJNtunStpq+64kD/e8DQAmbtEjxSRPJdk2en+/pJ2SVvc9DABaOK7bFG2vk3SBpPt6WQMAjXWOou3lkm6RdHWSFxY4faPtGdszh557aZwbAWBiOkXR9hLNBfH6JLcudJ4km5NMJ5k++bSl49wIABPT5d5nS7pW0s4kX+x/EgC00+VI8SJJH5V0se3to7ff7HkXADSx6ENyktwracGXAgSANxt+owUACqIIAAVRBICCKAJAQRQBoCCKAFAQRQAoiCIAFEQRAAqiCAAFUQSAgigCQEEUAaDo8sJVx+3g/qXa+v1z+rjo1+0XZw63njDPi3+V1hMWtGfbqtYT5vnrD29uPWGeTQ9uaD1hnhX/+NbWExZ00598ofWEo3zwgWeOeRpHigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEACqIIAAVRBIBi0SjaXmr7ftsP2t5h+/OTGAYALXR5PsVDki5OcsD2Ekn32v6nJD/oeRsATNyiUUwSSQdGHy4ZvQ3z2VEB4A3qdJui7Snb2yXtlXRXkvt6XQUAjXSKYpLZJOdLWiNpve3zXnse2xttz9ieOXzw4JhnAsBkHNe9z0mel3S3pEsXOG1zkukk0ycsWzameQAwWV3ufV5pe8Xo/VMkXSLpkZ53AUATXe59XiXpOttTmovoTUlu73cWALTR5d7nf5V0wQS2AEBz/EYLABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEACqIIAAVRBICCKAJA0eWpw47bkhelM2cO93HRr9uym4f3CgpT289qPWFBZz+3q/WEeTa9Y0PrCfOs+bPhHVN85eYvtJ6woE9ceGXrCUf5j+dvPeZpw/tXBYCGiCIAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGg6BxF21O2f2j79j4HAUBLx3OkuEnSzr6GAMAQdIqi7TWSLpd0Tb9zAKCtrkeKX5L0GUnHfI0B2xttz9ieeeXQgXFsA4CJWzSKtq+QtDfJ1v/rfEk2J5lOMr3k5OVjGwgAk9TlSPEiSVfafkLSjZIutv21XlcBQCOLRjHJ55KsSbJO0gZJ303ykd6XAUADPE4RAIrjet3nJN+T9L1elgDAAHCkCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBzXs+R09c61T+vev/xqHxf9up33O7/besI8qz+0o/WEBd35n9tbT5hnw48vbj1hnvdet6v1hHnOWjLMZ72ffXZf6wlHSWaPeRpHigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEACqIIAAVRBICi01OH2X5C0n5Js5JeTTLd5ygAaOV4nk/x15M809sSABgArj4DQNE1ipH0bdtbbW/scxAAtNT16vO7k+yx/QuS7rL9SJJ76hlGsdwoSb+0updXOQCA3nU6UkyyZ/TfvZJuk7R+gfNsTjKdZHrl26fGuxIAJmTRKNpeZvvUI+9Ler+kh/oeBgAtdLmee6ak22wfOf/Xk3yr11UA0MiiUUzyuKR3TWALADTHQ3IAoCCKAFAQRQAoiCIAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABREEQAKJxn/hdpPS/rJGC7qDElDe7GsIW6ShrmLTd2wqbtx7frlJCsXOqGXKI6L7ZmhvZzqEDdJw9zFpm7Y1N0kdnH1GQAKoggAxdCjuLn1gAUMcZM0zF1s6oZN3fW+a9C3KQLApA39SBEAJoooAkBBFAGgIIoAUBBFACj+BxWhHYRVICOBAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "def rectangular_room_datagen(H=6, W=8, n_emissions=20, length=10000):\n", + " room = np.random.randint(n_emissions, size=(H, W))\n", + " actions = np.random.randint(4, size=length) # 0: left, 1: right, 2: up, 3: down\n", + " x = np.zeros(length, int)\n", + " rc = np.zeros((length, 2), int)\n", + " r, c = np.random.randint(H), np.random.randint(W)\n", + " x[0] = room[r, c]\n", + " rc[0] = r, c\n", + " for i, a in enumerate(actions[:-1]):\n", + " if a == 0 and 0 < c:\n", + " c -= 1\n", + " elif a == 1 and c < W - 1:\n", + " c += 1\n", + " elif a == 2 and 0 < r:\n", + " r -= 1\n", + " elif a == 3 and r < H - 1:\n", + " r += 1\n", + " x[i + 1] = room[r, c]\n", + " x[i + 1] = room[r, c]\n", + " rc[i + 1] = r, c\n", + " return room, actions, x, rc\n", + "\n", + "n_emissions=20\n", + "room, a, x, rc = rectangular_room_datagen(n_emissions=n_emissions)\n", + "plt.matshow(room)\n", + "\n", + "n_clones = np.ones(n_emissions, dtype=np.int64) * 10\n", + "chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a) # Initialize the model\n", + "progression = chmm.learn_em_T(x, a, n_iter=100, term_early=False) # Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Empty rectangular room datagen" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average number of clones: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:10<00:00, 9.60it/s, train_bps=0.0531]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUUAAAECCAYAAABpKcWJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAALgElEQVR4nO3d/6uedR3H8ders2PqtNnSRDZJSRFCSOWwEEVKsWaJ9aOCBhGcggylQLRfpP4AsR9KGGoZfssyI8z8Ahom5JezNdM5CxPFTWvacroE55mvfjj3ibdj61zT+7o/F+35gMPOl5v7vNj0ea7ruu9zjpMIALDgQ60HAMCQEEUAKIgiABREEQAKoggABVEEgGKwUbS91vZfbD9n+8oB7LnR9jbbT7fessj2sbYfsv2M7U22LxvApoNtP277ydGm77fetMj2lO0/2b679ZZFtl+w/ZTtjbbnWu+RJNtH2P6l7Wdtb7Z9euM9J43+fhZf3rB9eW+fb4jPU7Q9Jemvks6VtEXSE5IuSvJMw01nSdop6WdJTm61o7J9jKRjkmywfbik9ZK+0vjvyZKWJ9lpe1rSI5IuS/Joq02LbH9H0oykjyQ5v/UeaSGKkmaSvNZ6yyLbN0n6Q5LrbR8k6dAkrzeeJem/bdgq6TNJXuzjcwz1SHGNpOeSPJ9kl6TbJX255aAkD0va3nLDnpK8kmTD6PU3JW2WtKrxpiTZOXpzevTS/Cuv7dWSviTp+tZbhsz2CklnSbpBkpLsGkoQR86R9Le+gigNN4qrJL1U3t6ixv+zD53t4ySdKumxxlMWT1M3Stom6YEkzTdJulbSFZLebbxjT5F0v+31tmdbj5F0vKRXJf1kdKnhetvLW48qLpR0W5+fYKhRxH6wfZikOyVdnuSN1nuS7E5yiqTVktbYbnq5wfb5krYlWd9yxz6cmeQ0SedJ+tboMk1LyySdJum6JKdK+rek5tf0JWl0Kn+BpF/0+XmGGsWtko4tb68evQ97GF23u1PSLUl+1XpPNTrtekjS2sZTzpB0wej63e2SzrZ9c9tJC5JsHf25TdJdWrh01NIWSVvK0f0vtRDJIThP0oYk/+jzkww1ik9IOtH28aOvDhdK+k3jTYMzelDjBkmbk1zTeo8k2T7K9hGj1w/RwoNlz7bclOSqJKuTHKeF/5YeTHJxy02SZHv56AEyjU5RPy+p6bMbkvxd0ku2Txq96xxJzR6428NF6vnUWVo4VB6cJPO2L5V0n6QpSTcm2dRyk+3bJH1W0pG2t0i6OskNLTdp4QjoEklPja7hSdL3ktzTbpKOkXTT6FHCD0m6I8lgngIzMEdLumvha5uWSbo1yb1tJ0mSvi3pltEByfOSvtZ4z+IXjXMlfaP3zzXEp+QAQCtDPX0GgCaIIgAURBEACqIIAAVRBIBi0FEcyLc9vccQN0nD3MWmbtjU3SR2DTqKkob4DzPETdIwd7GpGzZ1d8BHEQAmqpcnbx+07NAcMr3iA9/Prt1v6aCpQ8ewSNq1ymO5n/k33tKyj4xn0zgNcRebujkQNk39c2os9/PO2zs1/eHDPvD9vP3Wdr3z9r/3GoVevs3vkOkVOv2Er/dx1+/biz8Y5Hc0AgeElTd/8JCN05MP/nCfH+P0GQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEACqIIAAVRBICCKAJAQRQBoOgURdtrbf/F9nO2r+x7FAC0smQUbU9J+pGk8yR9StJFtj/V9zAAaKHLkeIaSc8leT7JLkm3S/pyv7MAoI0uUVwl6aXy9pbR+wDg/87YHmixPWt7zvbcrt1vjet
uAWCiukRxq6Rjy9urR+97jyTrkswkmRnX71UBgEnrEsUnJJ1o+3jbB0m6UNJv+p0FAG0s+duckszbvlTSfZKmJN2YZFPvywCggU6/4i7JPZLu6XkLADTHd7QAQEEUAaAgigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACg6/UCI/XXUCTs0++vf9nHX79t1J57QegJwwLrv5Y2tJ7zHmi+8us+PcaQIAAVRBICCKAJAQRQBoCCKAFAQRQAoiCIAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABRLRtH2jba32X56EoMAoKUuR4o/lbS25x0AMAhLRjHJw5K2T2ALADTHNUUAKMYWRduztudsz+3YPj+uuwWAiRpbFJOsSzKTZGbFyl5+HxYA9I7TZwAoujwl5zZJf5R0ku0ttr/e/ywAaGPJ89wkF01iCAAMAafPAFAQRQAoiCIAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEACqIIAAVRBICCKAJAQRQBoCCKAFAQRQAoiCIAFEQRAIolo2j7WNsP2X7G9ibbl01iGAC0sKzDbeYlfTfJBtuHS1pv+4Ekz/S8DQAmbskjxSSvJNkwev1NSZslrep7GAC0sF/XFG0fJ+lUSY/1sgYAGuscRduHSbpT0uVJ3tjLx2dtz9me27F9fpwbAWBiOkXR9rQWgnhLkl/t7TZJ1iWZSTKzYmWXS5UAMDxdHn22pBskbU5yTf+TAKCdLkeKZ0i6RNLZtjeOXr7Y8y4AaGLJ89wkj0jyBLYAQHN8RwsAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEACqIIAAVRBICCKAJAQRQBoCCKAFAQRQAoiCIAFEQRAAqiCAAFUQSAYsko2j7Y9uO2n7S9yfb3JzEMAFpY1uE2b0s6O8lO29OSHrH9uySP9rwNACZuySgmiaSdozenRy/pcxQAtNLpmqLtKdsbJW2T9ECSx3pdBQCNdIpikt1JTpG0WtIa2yfveRvbs7bnbM/t2D4/5pkAMBn79ehzktclPSRp7V4+ti7JTJKZFSu7XKoEgOHp8ujzUbaPGL1+iKRzJT3b8y4AaKLLId0xkm6yPaWFiN6R5O5+ZwFAG10eff6zpFMnsAUAmuM7WgCgIIoAUBBFACiIIgAURBEACqIIAAVRBICCKAJAQRQBoCCKAFAQRQAoiCIAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACiIIgAURBEAis5RtD1l+0+27+5zEAC0tD9HipdJ2tzXEAAYgk5RtL1a0pckXd/vHABoq+uR4rWSrpD07r5uYHvW9pztuR3b58exDQAmbsko2j5f0rYk6//X7ZKsSzKTZGbFymVjGwgAk9TlSPEMSRfYfkHS7ZLOtn1zr6sAoJElo5jkqiSrkxwn6UJJDya5uPdlANAAz1MEgGK/Lv4l+b2k3/eyBAAGgCNFACiIIgAURBEACqIIAAVRBICCKAJAQRQBoCCKAFAQRQAoiCIAFEQRAAqiCAAFUQSAopcfkb311Y/p6h9/tY+7ft9Oe/Sp1hOAA9Ynf/7N1hPe4+V/XbvPj3GkCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABREEQAKoggABVEEgIIoAkBBFAGgIIoAUBBFACg6/egw2y9IelPSbknzSWb6HAUArezPz1P8XJLXelsCAAPA6TMAFF2jGEn3215ve7bPQQDQUtfT5zOTbLX9cUkP2H42ycP1BqNYzkrS9OEfHfNMAJiMTkeKSbaO/twm6S5Ja/Zym3VJZpLMTB26fLwrAWBCloyi7eW2D198XdLnJT3d9zAAaKHL6fPRku6yvXj7W5Pc2+sqAGhkySgmeV7SpyewBQCa4yk5AFAQRQAoiCIAFEQRAAqiCAAFUQSAgigCQEEUAaAgigBQEEUAKIgiABREEQAKoggAhZOM/07tVyW9OIa7OlLS0H5Z1hA3ScPcxaZu2NTduHZ9IslRe/tAL1EcF9tzQ/t1qkPcJA1zF5u6YVN3k9jF6TMAFEQRAIqhR3Fd6wF7McRN0jB3sakbNnXX+65BX1MEgEkb+pEiAEwUUQSAgigCQEEUAaAgigBQ/Aef7sxGSsfpLgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "def empty_rectangular_room_datagen(H=6, W=8, length=10000):\n", + " room = np.zeros((H, W),dtype=np.int64)\n", + " room[:] = 0\n", + " room[0] = 5\n", + " room[-1] = 6\n", + " room[:, 0] = 7\n", + " room[:, -1] = 8\n", + " room[0, 0] = 1\n", + " room[0, -1] = 2\n", + " room[-1, 0] = 3\n", + " room[-1, -1] = 4\n", + " actions = np.random.randint(4, size=length) # 0: left, 1: right, 2: up, 3: down\n", + " x = np.zeros(length, int)\n", + " rc = np.zeros((length, 2), int)\n", + " r, c = np.random.randint(H), np.random.randint(W)\n", + " x[0] = room[r, c]\n", + " rc[0] = r, c\n", + " for i, a in enumerate(actions[:-1]):\n", + " if a == 0 and 0 < c:\n", + " c -= 1\n", + " elif a == 1 and c < W - 1:\n", + " c += 1\n", + " elif a == 2 and 0 < r:\n", + " r -= 1\n", + " elif a == 3 and r < H - 1:\n", + " r += 1\n", + " x[i + 1] = room[r, c]\n", + " x[i + 1] = room[r, c]\n", + " rc[i + 1] = r, c\n", + " return room, actions, x, rc\n", + "\n", + "\n", + "room, a, x, rc = empty_rectangular_room_datagen()\n", + "plt.matshow(room)\n", + "\n", + "n_clones = np.ones(9, dtype=np.int64) * 40\n", + "chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a) # Initialize the model\n", + "progression = chmm.learn_em_T(x, a, n_iter=100, term_early=False) # Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3 Pentagon cliques" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average number of clones: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:10<00:00, 9.94it/s, train_bps=1.77]\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAANKUlEQVR4nO3db8yddX3H8fdnLVCLf/grCu1WYhgLISCkAdTFLdYJIqE+2APMWGCa8GSbaEgISDKzZyYa/yRbNA2gZBJ8gDgJUUtXNWbJrINS/pYBU0YLxVabidFMWvzuwTnV0txtb851nevc9/17v5LmPudc55zv79x/Pv1d17l+55uqQlK7/mDWA5A0W4aA1DhDQGqcISA1zhCQGmcISI1bECGQ5LIk/5XkmSQ3DVRzdZLvJXkiyeNJrh+i7rj2siQPJblvwJonJLk7yZNJtid5x0B1Pz7+/j6W5K4kK6ZU5/Yku5M8dtBtJyXZlOTp8dcTB6r76fH3+ZEk30hywhB1D9p2Q5JKcsp8nmvmIZBkGfDPwPuBc4APJTlngNL7gRuq6hzgEuBvB6oLcD2wfaBaB3wB+E5V/Qlw/hD1k5wBfBRYW1XnAsuAq6ZU7ivAZYfcdhOwuarOAjaPrw9RdxNwblWdBzwF3DxQXZKsBt4HPDffJ5p5CAAXAc9U1Y+r6mXga8D6aRetql1VtXV8+ZeM/ijOmHbdJKuADwC3TrvWQTXfBLwbuA2gql6uqv8dqPxy4HVJlgMrgRemUaSqfgDsPeTm9cAd48t3AB8com5V3V9V+8dXfwisGqLu2OeAG4F5nwW4EELgDGDHQdd3MsAf48GSrAEuALYMUO7zjH5Ivx2g1gFnAnuAL493Q25Ncvy0i1bV88BnGP2vtAv4RVXdP+26BzmtqnaNL78InDZg7QM+DHx7iEJJ1gPPV9XDr+VxCyEEZirJ64GvAx+rqpemXOsKYHdVPTjNOnNYDlwIfLGqLgB+xXSmxq8y3gdfzyiETgeOT3L1tOvOpUbnxw96jnySWxjtdt45QK2VwCeAf3itj10IIfA8sPqg66vGt01dkmMYBcCdVXXPACXfBVyZ5FlGuz3vSfLVAeruBHZW1YGZzt2MQmHa3gv8pKr2VNU+4B7gnQPUPeCnSd4KMP66e6jCSa4FrgD+qoZZoPM2RmH78Pj3axWwNclbjvbAhRAC/wmcleTMJMcyOnB077SLJgmjfeTtVfXZadcDqKqbq2pVVa1h9Dq/W1VT/5+xql4EdiQ5e3zTOuCJaddltBtwSZKV4+/3OoY9IHovcM348jXAN4comuQyRrt8V1bVr4eoWVWPVtWbq2rN+PdrJ3Dh+Gd/1AfP/B9wOaOjqP8N3DJQzT9lND18BNg2/nf5gK/5z4H7Bqz3duCB8ev9V+DEger+I/Ak8BjwL8BxU6pzF6PjDvvGfwAfAU5m9K7A08C/AScNVPcZRse5DvxefWmIuodsfxY4ZT7PlfEDJDVqIewOSJohQ0BqnCEgNc4QkBpnCEiNWzAhkOQ66y7Nui291sVYd8GEADCTb5x1l2xN687TQgoBSTMw6MlCx+a4WsHci9f28RuO4bjDPvaPz5vO2Zd7fv4Kp568bCrP/dQjKw+77Wivd1pmUbel1zrtukf6OzjS7/KzO/bxs72vZK5ty/sZ2vys4HguzrqJHrtx47Z+BzOAS09/+6yHoCVm0r+Diy7dcdht7g5IjTMEpMZ1CoFZfECopH5NHAIz/IBQST3qMhOYyQeESupXlxCY+QeESupu6m8Rjk9lvA5gBYd/31zSbHSZCczrA0KrakNVra2qtbM4cUPSkXUJgZl8QKikfk28O1BV+5P8HbCRUXup26vq8d5GJmkQnY4JVNW3gG/1NBZJM+AZg1LjBl1FuPb8FfWjjauPfsc5dFmMs/GFbRM/dlZcfKQ+banNvFR751xF6ExAapwhIDXOEJAaZwhIjTMEpMYZAlLjDAGpcYaA1DhDQGqcISA1zhCQGmcISI0zBKTGGQJS4wbtRdhFl+XArS1Dll4LZwJS4wwBqXGGgNS4Lr0IVyf5XpInkjye5Po+ByZpGF0ODO4HbqiqrUneADyYZFNVPdHT2CQNYOKZQFXtqqqt48u/BLZjL0Jp0enlmECSNcAFwJY+nk/ScDqHQJLXA18HPlZVL82x/bokDyR5YM/PX+laTlLPOoVAkmMYBcCdVXXPXPc5uCHpqScv61JO0hR0eXcgwG3A9qr6bH9DkjSkLjOBdwF/Dbwnybbxv8t7GpekgXTpSvzvwJxtjSQtHp4xKDXOEJAat2iWEnfhMmTp8JwJSI0zBKTGGQJS4wwBqXGGgNQ4Q0BqnCEgNc4QkBpnCEiNMwSkxhkCUuMMAalxhoDUOENAalyqarBib8xJdXHWTfTYxbgsd1bLkLvU1dK0pTbzUu2d85PAnAlIjTMEpMYZAlLj+uhAtCzJQ0nu62NAkobVx0zgekbNSCUtQl3bkK0CPgDc2s9wJA2t60zg88CNwG+7D0XSLHTpRXgFsLuqHjzK/X7XlXgfv5m0nKQp6dqL8MokzwJfY9ST8KuH3ungrsTHcFyHcpKmYeIQqKqbq2pVVa0BrgK+W1VX9zYySYPwPAGpcb20Iauq7wPf7+O5JA3LmYDUOENAatyi6Uq8GJfH2g15fhbjz3YpcSYgNc4QkBpnCEiNMwSkxhkCUuMMAalxhoDUOENAapwhIDXOEJAaZwhIjTMEpMYZAlLjDAGpcYtmKfFitBi7Es9iGbIdmGfLmYDUOENAapwhIDWuay/CE5LcneTJJNuTvKOvgUkaRtcDg18AvlNVf5nkWGBlD2OSNKCJQyDJm4B3A9cCVNXLwMv9DEvSULrsDpwJ7AG+nOShJLcmOb6ncUkaSJcQWA5cCHyxqi4AfgXcdOid7EosLWxdQmAnsLOqtoyv380oFF7FrsTSwtalK/GLwI4kZ49vWgc80cuoJA2m67sDfw/cOX5n4MfA33QfkqQhdQqBqtoGrO1nKJJmwTMGpcYZAlLjXEq8BM1iae5i64Ss33MmIDXOEJAaZwhIjTMEpMYZAlLjDAGpcYaA1DhDQGqcISA1zhCQGmcISI0zBKTGGQJS41xFqFeZdDXgYmuCqt9zJiA1zhCQGmcISI0zBKTGde1K/PEkjyd5LMldSVb0NTBJw5g4BJKcAXwUWFtV5wLLgKv6GpikYXTdHVgOvC7JckZtyV/oPiRJQ+rShux54DPAc8Au4BdVdX9fA5M0jC67AycC6xm1KD8dOD7J1XPcz67E0gLWZXfgvcBPqmpPVe0D7gHeeeid7EosLWxdQuA54JIkK5OEUVfi7f0MS9JQuhwT2ALcDWwFHh0/14aexiVpIF27En8S+GRPY5E0A54xKDXOpcQL1GJbmjuLJqizrLuUOBOQGmcISI0zBKTGGQJS4wwBqXGGgNQ4Q0BqnCEgNc4QkBpnCEiNMwSkxhkCUuMMAalxhoDUOJcSL0GLbYnsYlyGPCvT+Nk6E5AaZwhIjTMEpMYdNQSS3J5kd5LHDrrtpCSbkjw9/nridIcpaV
rmMxP4CnDZIbfdBGyuqrOAzePrkhaho4ZAVf0A2HvIzeuBO8aX7wA+2O+wJA1l0mMCp1XVrvHlF4HTehqPpIF1PjBYVQXU4bbbkFRa2CYNgZ8meSvA+Ovuw93RhqTSwjZpCNwLXDO+fA3wzX6GI2lo83mL8C7gP4Czk+xM8hHgU8BfJHmaUYvyT013mJKm5ahrB6rqQ4fZtK7nsUiaAc8YlBpnCEiNy+gdvmG8MSfVxXEvQv1xGfL8XHTpDh54+P8y1zZnAlLjDAGpcYaA1DhDQGqcISA1zhCQGmcISI0zBKTGGQJS4wwBqXGGgNQ4Q0BqnCEgNc4QkBpnV2I1y2XII84EpMYZAlLjDAGpcZN2Jf50kieTPJLkG0lOmOooJU3NpF2JNwHnVtV5wFPAzT2PS9JAJupKXFX3V9X+8dUfAqumMDZJA+jjmMCHgW/38DySZqDTeQJJbgH2A3ce4T7XAdcBrGBll3KSpmDiEEhyLXAFsK6O0LygqjYAG2DUd2DSepKmY6IQSHIZcCPwZ1X1636HJGlIk3Yl/ifgDcCmJNuSfGnK45Q0JZN2Jb5tCmORNAOeMSg1zhCQGudSYmkCS2kZsjMBqXGGgNQ4Q0BqnCEgNc4QkBpnCEiNMwSkxhkCUuMMAalxhoDUOENAapwhIDXOEJAaZwhIjcsRPiO0d2vPX1E/2rh6osd2WX4pzWWhdQeej0n/DrbUZl6qvZlrmzMBqXGGgNQ4Q0Bq3ERdiQ/adkOSSnLKdIYnadom7UpMktXA+4Dneh6TpAFN1JV47HOMuhDZWkxaxCY6JpBkPfB8VT3c83gkDew1f+R4kpXAJxjtCszn/r/rSvyHZ/gJ59JCM8lM4G3AmcDDSZ4FVgFbk7xlrjtX1YaqWltVa089ednkI5U0Fa/5v+aqehR484Hr4yBYW1U/63FckgYyaVdiSUvEpF2JD96+prfRSBqcZwxKjTMEpMYNupQ4yR7gfw6z+RRgFgcXrbs0a1r31f6oqk6da8OgIXAkSR6oqrXWXXp1W3qti7GuuwNS4wwBqXELKQQ2WHfJ1m3ptS66ugvmmICk2VhIMwFJM2AISI0zBKTGGQJS4wwBqXH/DzzDLhxYwEVFAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "T = np.zeros((15,15))\n", + "# Connect cliques\n", + "for i in range(0,4+1):\n", + " for j in range(0,4+1):\n", + " if i!=j:\n", + " T[i,j] = 1.0\n", + "for i in range(5,9+1):\n", + " for j in range(5,9+1):\n", + " if i!=j:\n", + " T[i,j] = 1.0\n", + "for i in range(10,14+1):\n", + " for j in range(10,14+1):\n", + " if i!=j:\n", + " T[i,j] = 1.0\n", + "# Disconnect in clique connector nodes\n", + "T[0,4] = 0.0\n", + "T[4,0] = 0.0\n", + "T[5,9] = 0.0\n", + "T[9,5] = 0.0\n", + "T[10,14] = 0.0\n", + "T[14,10] = 0.0\n", + "# Connect cross clique connector nodes\n", + "T[4,5] = 1.0\n", + "T[5,4] = 1.0\n", + "T[9,10] = 1.0\n", + "T[10,9] = 1.0\n", + "T[14,0] = 1.0\n", + "T[0,14] = 1.0\n", + "plt.matshow(T)\n", + "\n", + "# Draw data\n", + "states = [0]\n", + "for _ in range(10000):\n", + " prev_state = states[-1]\n", + " \n", + " possible_next_states = np.where(T[prev_state,:])[0]\n", + " next_state = np.random.choice(possible_next_states)\n", + " states.append(next_state)\n", + "states = np.array(states)\n", + "\n", + "state_to_obs = np.array([1,2,3,4,5,6,1,4,5,2,8,2,3,5,7], dtype=int) - 1 # Aliasing version\n", + "\n", + "# Create observation data\n", + "x = state_to_obs[states]\n", + "a = np.zeros(len(x),dtype=int)\n", + "\n", + "n_clones = np.ones(8, dtype=np.int64) * 40\n", + "chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a) # Initialize the model\n", + "progression = chmm.learn_em_T(x, a, n_iter=100, term_early=False) # Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5x5 mazes" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mroom\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mactions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mroom\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfivebyfive_room_datagen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroom\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mfivebyfive_room_datagen\u001b[0;34m(length)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfivebyfive_room_datagen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlength\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mroom\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermutation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m25\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mH\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mW\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mactions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlength\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 0: left, 1: right, 2: up, 3: down\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlength\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] + } + ], + "source": [ + "def fivebyfive_room_datagen(length=10000):\n", + " room = np.random.permutation(25).reshape(5,5)\n", + " H,W = 5,5\n", + " actions = np.random.randint(4, size=length) # 0: left, 1: right, 2: up, 3: down\n", + " x = np.zeros(length, int)\n", + " rc = np.zeros((length, 2), int)\n", + " r, c = np.random.randint(H), np.random.randint(W)\n", + " x[0] = room[r, c]\n", + " rc[0] = r, c\n", + " for i, a in enumerate(actions[:-1]):\n", + " if a == 0 and 0 < c:\n", + " c -= 1\n", + " elif a == 1 and c < W - 1:\n", + " c += 1\n", + " elif a == 2 and 0 < r:\n", + " r -= 1\n", + " elif a == 3 and r < H - 1:\n", + " r += 1\n", + " x[i + 1] = room[r, c]\n", + " x[i + 1] = room[r, c]\n", + " rc[i + 1] = r, c\n", + " return room, actions, x, rc\n", + "\n", + "room, a, x, rc = fivebyfive_room_datagen()\n", + "plt.matshow(room)\n", + "\n", + "n_clones = np.ones(25, dtype=np.int64) * 2\n", + "chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a) # Initialize the model\n", + "progression = chmm.learn_em_T(x, a, n_iter=100, term_early=False) # Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Two Rooms Stitched Together" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average number of clones: 18.8125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:17<00:00, 5.70it/s, train_bps=0.000256]\n" + ] + } + ], + "source": [ + "room2 = np.array([[11,11, 4,11,12, 3],\n", + " [ 1, 9,11,12,14,12],\n", + " [ 9, 4, 2, 7,14, 0],\n", + " [11, 2, 8, 9, 0, 0],\n", + " [ 2, 9, 6, 8,13, 8],\n", + " [14, 2,13, 4, 5, 0],\n", + " [14, 9, 7,13, 4,14],\n", + " [ 0, 3,13, 3, 1,11]])\n", + "\n", + "\n", + "room1 = np.array([[ 2,13, 1, 0, 4,12],\n", + " [10, 0, 11,12, 3, 7],\n", + " [ 4, 9, 12,14,12, 5],\n", + " [ 8, 4, 7, 14, 0, 5],\n", + " [ 1,13, 2, 7,10, 4],\n", + " [11,12, 3, 8,14, 3],\n", + " [12,14,12, 5, 1, 1],\n", + " [ 7,14, 0, 3, 9, 5]])\n", + "\n", + "def datagen(room, 
H=6, W=8, length=10000):\n", + " actions = np.random.randint(4, size=length) # 0: left, 1: right, 2: up, 3: down\n", + " x = np.zeros(length, int)\n", + " rc = np.zeros((length, 2), int)\n", + " r, c = np.random.randint(H), np.random.randint(W)\n", + " x[0] = room[r, c]\n", + " rc[0] = r, c\n", + " for i in range(len(actions)-1):\n", + " a = actions[i]\n", + " done = False\n", + " while not done:\n", + " done = True\n", + " if a == 0 and 0 < c:\n", + " c -= 1\n", + " elif a == 1 and c < W - 1:\n", + " c += 1\n", + " elif a == 2 and 0 < r:\n", + " r -= 1\n", + " elif a == 3 and r < H - 1:\n", + " r += 1\n", + " else:\n", + " a = np.random.randint(4)\n", + " done = False\n", + " actions[i] = a\n", + " x[i + 1] = room[r, c]\n", + " rc[i + 1] = r, c\n", + " return actions, x, rc\n", + "a1, x1, rc1 = datagen(room1, H=8, W=6, length=20000)\n", + "a2, x2, rc2 = datagen(room2, H=8, W=6, length=20000)\n", + "\n", + "x = np.hstack((0, x1+1, 0, x2+1))\n", + "a = np.hstack((4, a1[:-1], 4, 4, a2))\n", + "\n", + "n_emissions = x.max()+1\n", + "\n", + "n_clones = 20*np.ones(n_emissions,int)\n", + "n_clones[0] = 1\n", + "chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a) # Initialize the model\n", + "progression = chmm.learn_em_T(x, a, n_iter=100, term_early=False) # Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000..aea8883 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +numpy +matplotlib +numba +tqdm +jupyter +jupytext +scipy
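
A minimal usage sketch, condensed from the first training cell of `intro.ipynb` (it assumes `chmm_actions.py` is importable from the working directory): train a CHMM on a toy observation/action sequence and report its average per-step prediction probability on held-out data.

```python
import numpy as np
from chmm_actions import CHMM

OBS = 2                                                # number of distinct observation symbols
x = np.random.randint(OBS, size=1000, dtype=np.int64)  # observation sequence (replace with real data)
a = np.zeros(1000, dtype=np.int64)                     # action sequence (all zeros if the domain has no actions)
n_clones = np.ones(OBS, dtype=np.int64) * 5            # clones per observation, i.e. model capacity

chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a)  # initialize the model
progression = chmm.learn_em_T(x, a, n_iter=100)              # EM training of the transition matrix

x_test = np.random.randint(OBS, size=1000, dtype=np.int64)   # held-out observations
a_test = np.zeros(1000, dtype=np.int64)
bps = chmm.bps(x_test, a_test)                               # negative log2-likelihood per step (bits per step)
print("average prediction probability:", 2 ** (-bps.mean()))
```

With real data, `x` is the discrete observation stream, `a` is the accompanying action stream (kept at zero when there are no actions), and `n_clones` controls how many clones, and hence how much capacity, each observation receives.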