Skip to content

Commit

Permalink
MAINT: deal with log_prob_fn returning a variety of scalar-likes
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfaff committed Apr 5, 2024
1 parent 912387a commit 2c9923c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
28 changes: 25 additions & 3 deletions src/emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,19 @@ def compute_log_prob(self, coords):
results = list(map_func(self.log_prob_fn, p))

try:
log_prob = np.array([float(l[0]) for l in results])
blob = [l[1:] for l in results]
# perhaps log_prob_fn returns blobs?

# deal with the blobs first
# if l does not have a len attribute (i.e. not a sequence, no blob)
# then a TypeError is raised. However, no error will be raised if
# l is a length-1 array, np.array([1.234]). In that case blob
# will become an empty list.
blob = [l[1:] for l in results if len(l) > 1]
if not len(blob):
raise IndexError
log_prob = np.array([_scalar(l[0]) for l in results])
except (IndexError, TypeError):
log_prob = np.array([float(l) for l in results])
log_prob = np.array([_scalar(l) for l in results])
blob = None
else:
# Get the blobs dtype
Expand Down Expand Up @@ -682,3 +691,16 @@ def ndarray_to_list_of_dicts(
list of dictionaries of parameters
"""
return [{key: xi[val] for key, val in key_map.items()} for xi in x]


def _scalar(fx):
# Make sure a value is a true scalar
# 1.0, np.float64(1.0), np.array([1.0]), np.array(1.0)
if not np.isscalar(fx):
try:
fx = np.asarray(fx).item()
except (TypeError, ValueError) as e:
raise ValueError("log_prob_fn should return scalar") from e
return float(fx)
else:
return float(fx)
31 changes: 31 additions & 0 deletions src/emcee/tests/unit/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,34 @@ def test_run_mcmc(self):
assert results.coords.shape == (n_walkers, len(self.names))
chain = sampler.chain
assert chain.shape == (n_walkers, n_steps, len(self.names))


class TestLnProbFn(TestCase):
# checks that the log_prob_fn can deal with a variety of 'scalar-likes'
def lnpdf(self, x):
v = np.log(np.sqrt(np.pi) * np.exp(-((x / 2.0) ** 2)))
v = float(v[0])
assert np.isscalar(v)
return v

def lnpdf_arr1(self, x):
v = self.lnpdf(x)
return np.array([v])

def lnpdf_float64(self, x):
v = self.lnpdf(x)
return np.float64(v)

def lnpdf_arr0D(self, x):
v = self.lnpdf(x)
return np.array(v)

def test_deal_with_scalar_likes(self):
rng = np.random.default_rng()
fns = [
self.lnpdf, self.lnpdf_arr1, self.lnpdf_float64, self.lnpdf_arr0D
]
for fn in fns:
init = rng.random((50, 1))
sampler = EnsembleSampler(50, 1, fn)
_ = sampler.run_mcmc(initial_state=init, nsteps=20)

0 comments on commit 2c9923c

Please sign in to comment.