Skip to content

Commit

Permalink
Merge pull request #55 from minaskar/dev
Browse files Browse the repository at this point in the history
1.2.6
  • Loading branch information
minaskar authored Sep 20, 2024
2 parents 0adf84b + 15f377a commit 77cbe33
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
4 changes: 4 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ Copyright 2022-2024 Minas Karamanis and contributors.
Changelog
=========

**1.2.6 (20/09/24)**

- Removed unnecessary log-likelihood evaluations during evidence estimation

**1.2.5 (16/09/24)**

- Removed unnecessary log-likelihood evaluations during MCMC sampling.
Expand Down
2 changes: 1 addition & 1 deletion pocomc/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.2.5"
version = "1.2.6"
20 changes: 18 additions & 2 deletions pocomc/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,21 +882,37 @@ def _compute_evidence(self, n=5_000):
dlogz : float
Estimate of the error on the log evidence.
"""
# sample from the flow
with torch.no_grad():
theta_q, logq = self.flow.sample(n)
theta_q = torch_to_numpy(theta_q)
logq = torch_to_numpy(logq)

# reparameterize
x_q, logdetj = self.scaler.inverse(theta_q)
logl, _ = self._log_like(x_q)

# compute log prior
logp = self.log_prior(x_q)

# keep only finite values
x_q = x_q[np.isfinite(logp)]
logdetj = logdetj[np.isfinite(logp)]
logq = logq[np.isfinite(logp)]
logp = logp[np.isfinite(logp)]

# compute log likelihood
logl, _ = self._log_like(x_q)

# compute log weights
logw = logl + logp + logdetj - logq

# compute log evidence
logz = np.logaddexp.reduce(logw) - np.log(len(logw))

# compute error on log evidence
dlogz = np.std([np.logaddexp.reduce(logw[np.random.choice(len(logw), len(logw))]) - np.log(len(logw)) for _ in range(np.maximum(n,1000))])

self.calls += n
self.calls += len(logw)
self.pbar.update_stats(dict(calls=self.calls))

self.logz = logz
Expand Down

0 comments on commit 77cbe33

Please sign in to comment.