WIP: use Generator instead of RandomState #528

Open · wants to merge 4 commits into base: main
24 changes: 16 additions & 8 deletions src/emcee/backends/hdf.py
@@ -2,6 +2,7 @@

from __future__ import division, print_function

import json
import os
from tempfile import NamedTemporaryFile

@@ -19,6 +20,13 @@
h5py = None


class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)


def does_hdf5_support_longdouble():
if h5py is None:
return False
@@ -193,12 +201,11 @@ def accepted(self):
@property
def random_state(self):
with self.open() as f:
elements = [
v
for k, v in sorted(f[self.name].attrs.items())
if k.startswith("random_state_")
]
return elements if len(elements) else None
try:
dct = json.loads(f[self.name].attrs["random_state"])
except KeyError:
return None
return dct

def grow(self, ngrow, blobs):
"""Expand the storage space by some number of samples
@@ -261,8 +268,9 @@ def save_step(self, state, accepted):
g["blobs"][iteration, :] = state.blobs
g["accepted"][:] += accepted

for i, v in enumerate(state.random_state):
g.attrs["random_state_{0}".format(i)] = v
g.attrs["random_state"] = json.dumps(
state.random_state, cls=NumpyEncoder
)

g.attrs["iteration"] = iteration + 1

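For context on why the JSON encoder is needed: some bit generators (MT19937, for example) keep part of their state as a numpy array, which `json.dumps` cannot serialize directly. A minimal round-trip sketch of the idea, separate from the backend itself (the explicit `asarray` conversion on restore is an illustrative precaution, not something this PR does):

```python
import json

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    # Same idea as the encoder added to hdf.py: fall back to plain lists for
    # numpy arrays so the state dict becomes JSON-serializable.
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


# MT19937 stores its key as a uint32 array, so plain json.dumps would fail.
rng = np.random.Generator(np.random.MT19937(42))
encoded = json.dumps(rng.bit_generator.state, cls=NumpyEncoder)

# Restore onto a fresh bit generator of the same type; the key comes back as a
# list, so convert it to an array before assigning the state.
state = json.loads(encoded)
state["state"]["key"] = np.asarray(state["state"]["key"], dtype=np.uint32)
restored = np.random.Generator(np.random.MT19937())
restored.bit_generator.state = state

# Both generators now produce the same stream.
assert rng.integers(10, size=5).tolist() == restored.integers(10, size=5).tolist()
```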
50 changes: 36 additions & 14 deletions src/emcee/ensemble.py
@@ -73,6 +73,8 @@ class EnsembleSampler(object):
names of individual parameters or groups of parameters. If
specified, the ``log_prob_fn`` will receive a dictionary of
parameters, rather than a ``np.ndarray``.
rng (Optional):
An integer seed or a :class:`np.random.Generator` instance, used for reproducibility.

"""

@@ -89,6 +91,7 @@ def __init__(
vectorize=False,
blobs_dtype=None,
parameter_names: Optional[Union[Dict[str, int], List[str]]] = None,
rng=None,
# Deprecated...
a=None,
postargs=None,
@@ -136,11 +139,14 @@ def __init__(
self.nwalkers = nwalkers
self.backend = Backend() if backend is None else backend

# This is a random number generator that we can easily set the state
# of
self._random = np.random.default_rng(rng)

# Deal with re-used backends
if not self.backend.initialized:
self._previous_state = None
self.reset()
state = np.random.get_state()
else:
# Check the backend shape
if self.backend.shape != (self.nwalkers, self.ndim):
@@ -153,19 +159,14 @@

# Get the last random state
state = self.backend.random_state
if state is None:
state = np.random.get_state()
if state is not None:
self.random_state = state

# Grab the last step so that we can restart
it = self.backend.iteration
if it > 0:
self._previous_state = self.get_last_sample()

# This is a random number generator that we can easily set the state
# of without affecting the numpy-wide generator
self._random = np.random.mtrand.RandomState()
self._random.set_state(state)

# Do a little bit of _magic_ to make the likelihood call with
# ``args`` and ``kwargs`` pickleable.
self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs)
@@ -216,14 +217,25 @@ def __init__(
@property
def random_state(self):
"""
The state of the internal random number generator. In practice, it's
the result of calling ``get_state()`` on a
``numpy.random.mtrand.RandomState`` object. You can try to set this
The state of the internal random number generator. You can try to set this
property but be warned that if you do this and it fails, it will do
so silently.

"""
return self._random.get_state()

def rng_dict(rng):
bg_state = rng.bit_generator.state
ss = rng.bit_generator.seed_seq
ss_dict = dict(
entropy=ss.entropy,
spawn_key=ss.spawn_key,
pool_size=ss.pool_size,
n_children_spawned=ss.n_children_spawned,
)
return dict(bg_state=bg_state, seed_seq=ss_dict)

return rng_dict(self._random)
# return self._random.bit_generator.state

@random_state.setter # NOQA
def random_state(self, state):
@@ -232,8 +244,18 @@ def random_state(self, state):
if it doesn't work. Don't say I didn't warn you...

"""

def _rng_fromdict(d):
bg_state = d["bg_state"]
ss = np.random.SeedSequence(**d["seed_seq"])
bg = getattr(np.random, bg_state["bit_generator"])(ss)
bg.state = bg_state
rng = np.random.Generator(bg)
return rng

try:
self._random.set_state(state)
self._random = _rng_fromdict(state)
# self._random.bit_generator = state
except:
pass
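Taken together, the getter and setter above amount to a plain-dict round trip for the Generator. A standalone sketch of that round trip (helper names are illustrative; ``bit_generator.seed_seq`` is only available as a public attribute on newer numpy releases):

```python
import numpy as np


def generator_to_dict(rng):
    # Capture both the bit-generator state and the SeedSequence that made it,
    # mirroring the rng_dict helper in the getter.
    ss = rng.bit_generator.seed_seq
    seed_seq = dict(
        entropy=ss.entropy,
        spawn_key=ss.spawn_key,
        pool_size=ss.pool_size,
        n_children_spawned=ss.n_children_spawned,
    )
    return dict(bg_state=rng.bit_generator.state, seed_seq=seed_seq)


def generator_from_dict(d):
    # Rebuild the same kind of bit generator (e.g. PCG64) and restore its state.
    bg_state = d["bg_state"]
    ss = np.random.SeedSequence(**d["seed_seq"])
    bg = getattr(np.random, bg_state["bit_generator"])(ss)
    bg.state = bg_state
    return np.random.Generator(bg)


rng = np.random.default_rng(123)
rng.standard_normal(5)  # advance the stream a little first
clone = generator_from_dict(generator_to_dict(rng))

# The clone continues the stream from exactly the same point.
assert rng.standard_normal(3).tolist() == clone.standard_normal(3).tolist()
```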

@@ -325,7 +347,7 @@ def sample(
# Try to set the initial value of the random number generator. This
# fails silently if it doesn't work but that's what we want because
# we'll just interpret any garbage as letting the generator stay in
# it's current state.
# its current state.
if rstate0 is not None:
deprecation_warning(
"The 'rstate0' argument is deprecated, use a 'State' "
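With this change, reproducibility is controlled by passing a seed or a `Generator` to the constructor instead of seeding numpy's global state. A minimal usage sketch against this branch (the log-probability function and initial coordinates are made up for illustration):

```python
import numpy as np

import emcee


def log_prob(x):
    # Toy isotropic Gaussian target, purely for illustration.
    return -0.5 * np.sum(x ** 2)


ndim, nwalkers = 3, 16
rng = np.random.default_rng(42)
coords = rng.standard_normal((nwalkers, ndim))

# The sampler draws all of its random numbers from ``rng``; an int seed should
# also work, since the argument is passed straight to np.random.default_rng.
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, rng=rng)
sampler.run_mcmc(coords, 100)
```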
4 changes: 3 additions & 1 deletion src/emcee/moves/de.py
@@ -53,7 +53,9 @@ def get_proposal(self, s, c, random):
diffs = np.diff(c[pairs], axis=1).squeeze(axis=1) # (ns, ndim)

# Sample a gamma value for each walker following Nelson et al. (2013)
gamma = self.g0 * (1 + self.sigma * random.randn(ns, 1)) # (ns, 1)
gamma = self.g0 * (
1 + self.sigma * random.standard_normal((ns, 1))
) # (ns, 1)

# In this way, sigma is the standard deviation of the distribution of gamma,
# instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006).
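This rename, and the ones in the remaining moves below, follow the standard RandomState-to-Generator API mapping. A quick side-by-side sketch; note that even with the same seed the two APIs produce different streams, so results will not be bit-for-bit identical to main:

```python
import numpy as np

legacy = np.random.RandomState(0)  # legacy API used on main
rng = np.random.default_rng(0)     # Generator API used on this branch

# Standard normal draws: randn(*shape) becomes standard_normal(shape).
legacy.randn(4, 2)
rng.standard_normal((4, 2))

# Uniform draws on [0, 1): rand(n) becomes random(n).
legacy.rand(4)
rng.random(4)

# Integer draws on [0, high): randint becomes integers (high exclusive in both).
legacy.randint(10, size=3)
rng.integers(10, size=3)
```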
2 changes: 1 addition & 1 deletion src/emcee/moves/de_snooker.py
@@ -35,7 +35,7 @@ def get_proposal(self, s, c, random):
q = np.empty_like(s)
metropolis = np.empty(Ns, dtype=np.float64)
for i in range(Ns):
w = np.array([c[j][random.randint(Nc[j])] for j in range(3)])
w = np.array([c[j][random.integers(Nc[j])] for j in range(3)])
random.shuffle(w)
z, z1, z2 = w
delta = s[i] - z
10 changes: 7 additions & 3 deletions src/emcee/moves/gaussian.py
@@ -87,13 +87,15 @@ def get_factor(self, rng):
return np.exp(rng.uniform(-self._log_factor, self._log_factor))

def get_updated_vector(self, rng, x0):
return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape))
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal(
(x0.shape)
)

def __call__(self, x0, rng):
nw, nd = x0.shape
xnew = self.get_updated_vector(rng, x0)
if self.mode == "random":
m = (range(nw), rng.randint(x0.shape[-1], size=nw))
m = (range(nw), rng.integers(x0.shape[-1], size=nw))
elif self.mode == "sequential":
m = (range(nw), self.index % nd + np.zeros(nw, dtype=int))
self.index = (self.index + 1) % nd
@@ -106,7 +108,9 @@

class _diagonal_proposal(_isotropic_proposal):
def get_updated_vector(self, rng, x0):
return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape))
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal(
(x0.shape)
)


class _proposal(_isotropic_proposal):
2 changes: 1 addition & 1 deletion src/emcee/moves/mh.py
@@ -56,7 +56,7 @@ def propose(self, model, state):

# Loop over the walkers and update them accordingly.
lnpdiff = new_log_probs - state.log_prob + factors
accepted = np.log(model.random.rand(nwalkers)) < lnpdiff
accepted = np.log(model.random.random(nwalkers)) < lnpdiff

# Update the parameters
new_state = State(q, log_prob=new_log_probs, blobs=new_blobs)
2 changes: 1 addition & 1 deletion src/emcee/moves/red_blue.py
@@ -97,7 +97,7 @@ def propose(self, model, state):
zip(all_inds[S1], factors, new_log_probs)
):
lnpdiff = f + nlp - state.log_prob[j]
if lnpdiff > np.log(model.random.rand()):
if lnpdiff > np.log(model.random.random()):
accepted[j] = True

new_state = State(q, log_prob=new_log_probs, blobs=new_blobs)
4 changes: 2 additions & 2 deletions src/emcee/moves/stretch.py
@@ -27,7 +27,7 @@ def get_proposal(self, s, c, random):
c = np.concatenate(c, axis=0)
Ns, Nc = len(s), len(c)
ndim = s.shape[1]
zz = ((self.a - 1.0) * random.rand(Ns) + 1) ** 2.0 / self.a
zz = ((self.a - 1.0) * random.random(Ns) + 1) ** 2.0 / self.a
factors = (ndim - 1.0) * np.log(zz)
rint = random.randint(Nc, size=(Ns,))
rint = random.integers(Nc, size=(Ns,))
return c[rint] - (c[rint] - s) * zz[:, None], factors
33 changes: 9 additions & 24 deletions src/emcee/tests/unit/test_backends.py
@@ -43,11 +43,10 @@ def run_sampler(
):
if lp is None:
lp = normal_log_prob_blobs if blobs else normal_log_prob
if seed is not None:
np.random.seed(seed)
coords = np.random.randn(nwalkers, ndim)
rng = np.random.default_rng(seed)
coords = rng.standard_normal((nwalkers, ndim))
sampler = EnsembleSampler(
nwalkers, ndim, lp, backend=backend, blobs_dtype=dtype
nwalkers, ndim, lp, rng=rng, backend=backend, blobs_dtype=dtype
)
sampler.run_mcmc(coords, nsteps, thin_by=thin_by)
return sampler
@@ -125,10 +124,7 @@ def test_backend(backend, dtype, blobs):
last2 = sampler2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert all(
np.allclose(l1, l2)
for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:])
)
assert last1.random_state["bg_state"] == last2.random_state["bg_state"]
if blobs:
_custom_allclose(last1.blobs, last2.blobs)
else:
@@ -146,7 +142,6 @@ def test_reload(backend, dtype):

# Test the state
state = backend1.random_state
np.random.set_state(state)

# Load the file using a new backend object.
backend2 = backends.HDFBackend(
@@ -156,11 +151,7 @@
with pytest.raises(RuntimeError):
backend2.reset(32, 3)

assert state[0] == backend2.random_state[0]
assert all(
np.allclose(a, b)
for a, b in zip(state[1:], backend2.random_state[1:])
)
assert state == backend2.random_state

# Check all of the components.
for k in ["chain", "log_prob", "blobs"]:
@@ -172,10 +163,7 @@
last2 = backend2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert all(
np.allclose(l1, l2)
for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:])
)
assert last1.random_state == last2.random_state
_custom_allclose(last1.blobs, last2.blobs)

a = backend1.accepted
@@ -188,11 +176,11 @@ def test_restart(backend, dtype):
# Run a sampler with the default backend.
b = backends.Backend()
run_sampler(b, dtype=dtype)
sampler1 = run_sampler(b, seed=None, dtype=dtype)
sampler1 = run_sampler(b, seed=2, dtype=dtype)

with backend() as be:
run_sampler(be, dtype=dtype)
sampler2 = run_sampler(be, seed=None, dtype=dtype)
sampler2 = run_sampler(be, seed=2, dtype=dtype)

# Check all of the components.
for k in ["chain", "log_prob", "blobs"]:
@@ -204,10 +192,7 @@
last2 = sampler2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert all(
np.allclose(l1, l2)
for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:])
)
assert last1.random_state["bg_state"] == last2.random_state["bg_state"]
_custom_allclose(last1.blobs, last2.blobs)

a = sampler1.acceptance_fraction
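The reworked assertions compare generator states as plain dicts rather than element-wise over the legacy state tuple. A small sketch of the property the tests rely on:

```python
import numpy as np

a = np.random.default_rng(2)
b = np.random.default_rng(2)

# Identically seeded generators carry identical bit-generator state dicts ...
assert a.bit_generator.state == b.bit_generator.state

# ... and the states diverge as soon as one of them draws anything.
a.standard_normal(10)
assert a.bit_generator.state != b.bit_generator.state
```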
8 changes: 5 additions & 3 deletions src/emcee/tests/unit/test_sampler.py
@@ -135,9 +135,11 @@ def run_sampler(
progress=False,
store=True,
):
np.random.seed(seed)
coords = np.random.randn(nwalkers, ndim)
sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, backend=backend)
rng = np.random.default_rng(seed)
coords = rng.standard_normal((nwalkers, ndim))
sampler = EnsembleSampler(
nwalkers, ndim, normal_log_prob, rng=rng, backend=backend
)
sampler.run_mcmc(
coords,
nsteps,
6 changes: 3 additions & 3 deletions src/emcee/tests/unit/test_stretch.py
@@ -16,12 +16,12 @@ def test_live_dangerously(nwalkers=32, nsteps=3000, seed=1234):
warnings.filterwarnings("error")

# Set up the random number generator.
np.random.seed(seed)
rng = np.random.default_rng(seed)
state = State(
np.random.randn(nwalkers, 2 * nwalkers),
rng.standard_normal((nwalkers, 2 * nwalkers)),
log_prob=np.random.randn(nwalkers),
)
model = Model(None, lambda x: (np.zeros(len(x)), None), map, np.random)
model = Model(None, lambda x: (np.zeros(len(x)), None), map, rng)
proposal = moves.StretchMove()

# Test to make sure that the error is thrown if there aren't enough