Skip to content

Commit

Permalink
addedd option for gskip
Browse files Browse the repository at this point in the history
  • Loading branch information
rsexton2 committed Feb 8, 2025
1 parent ea4e514 commit a093ed4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
4 changes: 3 additions & 1 deletion basicrta/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ class ProcessProtein(object):
:type cutoff: float
"""

def __init__(self, niter, prot, cutoff):
def __init__(self, niter, prot, cutoff, gskip):
self.residues = {}
self.niter = niter
self.prot = prot
self.cutoff = cutoff
self.gskip = gskip

def __getitem__(self, item):
return getattr(self, item)
Expand All @@ -43,6 +44,7 @@ def _single_residue(self, adir, process=False):
result = f'{adir}/gibbs_{self.niter}.pkl'
g = Gibbs().load(result)
if process:
g.gskip = self.gskip
g.process_gibbs()
except ValueError:
result = None
Expand Down
21 changes: 12 additions & 9 deletions basicrta/gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,12 @@ def cluster(self, method="GaussianMixture", **kwargs):
from scipy import stats

clu = getattr(mixture, method)
burnin_ind = self.burnin // (self.g*self.gskip)
burnin_ind = self.burnin // self.g
data_len = len(self.times)
wcutoff = 10 / data_len

weights, rates = self.mcweights[burnin_ind:], self.mcrates[burnin_ind:]
weights = self.mcweights[burnin_ind::self.gskip]
rates = self.mcrates[burnin_ind::self.gskip]
lens = np.array([len(row[row > wcutoff]) for row in weights])
lmode = stats.mode(lens).mode
train_param = lmode
Expand All @@ -258,7 +259,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
all_labels = r.predict(np.log(data))

if self.indicator is not None:
indicator = self.indicator[burnin_ind:]
indicator = self.indicator[burnin_ind::self.gskip]
else:
indicator = self._sample_indicator()

Expand All @@ -285,13 +286,15 @@ def process_gibbs(self):
data_len = len(self.times)
wcutoff = 10/data_len
burnin_ind = self.burnin//self.g
inds = np.where(self.mcweights[burnin_ind:] > wcutoff)
inds = np.where(self.mcweights[burnin_ind::self.gskip] > wcutoff)
indices = (np.arange(self.burnin, self.niter + 1, self.g*self.gskip)
[inds[0]] // self.g)
weights, rates = self.mcweights[burnin_ind:], self.mcrates[burnin_ind:]
weights = self.mcweights[burnin_ind::self.gskip]
rates = self.mcrates[burnin_ind::self.gskip]
fweights, frates = weights[inds], rates[inds]

lens = [len(row[row > wcutoff]) for row in self.mcweights[burnin_ind:]]
lens = [len(row[row > wcutoff]) for row in
self.mcweights[burnin_ind::self.gskip]]
lmode = stats.mode(lens).mode

self.cluster(n_init=117, n_components=lmode)
Expand Down Expand Up @@ -320,8 +323,8 @@ def result_plot(self, remove_noise=False, **kwargs):
mixture_and_plot(self, remove_noise=remove_noise, **kwargs)

def _sample_indicator(self):
indicator = np.zeros(((self.niter+1)//self.g, self.times.shape[0]),
dtype=np.uint8)
indicator = np.zeros(((self.niter+1)//(self.g*self.gskip),
self.times.shape[0]), dtype=np.uint8)
burnin_ind = self.burnin//self.g
for i, (w, r) in enumerate(zip(self.mcweights, self.mcrates)):
# compute probabilities
Expand All @@ -332,7 +335,7 @@ def _sample_indicator(self):
s = np.argmax(rng.multinomial(1, z), axis=1)
indicator[i] = s
setattr(self, 'indicator', indicator)
return indicator[burnin_ind:]
return indicator[burnin_ind::self.gskip]

def save(self):
"""
Expand Down
3 changes: 2 additions & 1 deletion basicrta/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,8 @@ def mixture_and_plot(gibbs, scale=2, sparse=1, remove_noise=False, wlim=None,
else:
wmin, wmax = wcutoff, 2

weights, rates = gibbs.mcweights[burnin_ind:], gibbs.mcrates[burnin_ind:]
weights = gibbs.mcweights[burnin_ind::gibbs.gskip]
rates = gibbs.mcrates[burnin_ind::gibbs.gskip]
lens = np.array([len(row[row > wcutoff]) for row in weights])
lmode = stats.mode(lens).mode
train_param = lmode
Expand Down

0 comments on commit a093ed4

Please sign in to comment.