diff --git a/basicrta/cluster.py b/basicrta/cluster.py index 275acb1..da0bba7 100644 --- a/basicrta/cluster.py +++ b/basicrta/cluster.py @@ -28,11 +28,12 @@ class ProcessProtein(object): :type cutoff: float """ - def __init__(self, niter, prot, cutoff): + def __init__(self, niter, prot, cutoff, gskip): self.residues = {} self.niter = niter self.prot = prot self.cutoff = cutoff + self.gskip = gskip def __getitem__(self, item): return getattr(self, item) @@ -43,6 +44,7 @@ def _single_residue(self, adir, process=False): result = f'{adir}/gibbs_{self.niter}.pkl' g = Gibbs().load(result) if process: + g.gskip = self.gskip g.process_gibbs() except ValueError: result = None diff --git a/basicrta/gibbs.py b/basicrta/gibbs.py index 41bad42..687a1eb 100644 --- a/basicrta/gibbs.py +++ b/basicrta/gibbs.py @@ -231,11 +231,12 @@ def cluster(self, method="GaussianMixture", **kwargs): from scipy import stats clu = getattr(mixture, method) - burnin_ind = self.burnin // (self.g*self.gskip) + burnin_ind = self.burnin // self.g data_len = len(self.times) wcutoff = 10 / data_len - weights, rates = self.mcweights[burnin_ind:], self.mcrates[burnin_ind:] + weights = self.mcweights[burnin_ind::self.gskip] + rates = self.mcrates[burnin_ind::self.gskip] lens = np.array([len(row[row > wcutoff]) for row in weights]) lmode = stats.mode(lens).mode train_param = lmode @@ -258,7 +259,7 @@ def cluster(self, method="GaussianMixture", **kwargs): all_labels = r.predict(np.log(data)) if self.indicator is not None: - indicator = self.indicator[burnin_ind:] + indicator = self.indicator[burnin_ind::self.gskip] else: indicator = self._sample_indicator() @@ -285,13 +286,15 @@ def process_gibbs(self): data_len = len(self.times) wcutoff = 10/data_len burnin_ind = self.burnin//self.g - inds = np.where(self.mcweights[burnin_ind:] > wcutoff) + inds = np.where(self.mcweights[burnin_ind::self.gskip] > wcutoff) indices = (np.arange(self.burnin, self.niter + 1, self.g*self.gskip) [inds[0]] // self.g) - weights, rates = self.mcweights[burnin_ind:], self.mcrates[burnin_ind:] + weights = self.mcweights[burnin_ind::self.gskip] + rates = self.mcrates[burnin_ind::self.gskip] fweights, frates = weights[inds], rates[inds] - lens = [len(row[row > wcutoff]) for row in self.mcweights[burnin_ind:]] + lens = [len(row[row > wcutoff]) for row in + self.mcweights[burnin_ind::self.gskip]] lmode = stats.mode(lens).mode self.cluster(n_init=117, n_components=lmode) @@ -320,8 +323,8 @@ def result_plot(self, remove_noise=False, **kwargs): mixture_and_plot(self, remove_noise=remove_noise, **kwargs) def _sample_indicator(self): - indicator = np.zeros(((self.niter+1)//self.g, self.times.shape[0]), - dtype=np.uint8) + indicator = np.zeros(((self.niter+1)//(self.g*self.gskip), + self.times.shape[0]), dtype=np.uint8) burnin_ind = self.burnin//self.g for i, (w, r) in enumerate(zip(self.mcweights, self.mcrates)): # compute probabilities @@ -332,7 +335,7 @@ def _sample_indicator(self): s = np.argmax(rng.multinomial(1, z), axis=1) indicator[i] = s setattr(self, 'indicator', indicator) - return indicator[burnin_ind:] + return indicator[burnin_ind::self.gskip] def save(self): """ diff --git a/basicrta/util.py b/basicrta/util.py index b06ec8b..180769c 100644 --- a/basicrta/util.py +++ b/basicrta/util.py @@ -704,7 +704,8 @@ def mixture_and_plot(gibbs, scale=2, sparse=1, remove_noise=False, wlim=None, else: wmin, wmax = wcutoff, 2 - weights, rates = gibbs.mcweights[burnin_ind:], gibbs.mcrates[burnin_ind:] + weights = gibbs.mcweights[burnin_ind::gibbs.gskip] + rates = gibbs.mcrates[burnin_ind::gibbs.gskip] lens = np.array([len(row[row > wcutoff]) for row in weights]) lmode = stats.mode(lens).mode train_param = lmode