diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3a3dc22a..063f0c0a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11", "3.12"] os: ["ubuntu-latest"] include: - python-version: "3.9" @@ -49,6 +49,30 @@ jobs: COVERALLS_PARALLEL: true COVERALLS_FLAG_NAME: ${{ matrix.python-version }}-${{ matrix.os }} + leading_edge: + runs-on: ${{ matrix.os }} + strategy: + matrix: + python-version: ["3.12"] + os: ["ubuntu-latest"] + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install -U pip + python -m pip install pip install pytest "numpy>=2.0.0rc1" + python -m pip install -e. + - name: Run tests + run: pytest + coverage: needs: tests runs-on: ubuntu-latest diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 3212099d..cfef4a57 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -489,10 +489,19 @@ def compute_log_prob(self, coords): results = list(map_func(self.log_prob_fn, p)) try: - log_prob = np.array([float(l[0]) for l in results]) - blob = [l[1:] for l in results] + # perhaps log_prob_fn returns blobs? + + # deal with the blobs first + # if l does not have a len attribute (i.e. not a sequence, no blob) + # then a TypeError is raised. However, no error will be raised if + # l is a length-1 array, np.array([1.234]). In that case blob + # will become an empty list. + blob = [l[1:] for l in results if len(l) > 1] + if not len(blob): + raise IndexError + log_prob = np.array([_scalar(l[0]) for l in results]) except (IndexError, TypeError): - log_prob = np.array([float(l) for l in results]) + log_prob = np.array([_scalar(l) for l in results]) blob = None else: # Get the blobs dtype @@ -502,7 +511,7 @@ def compute_log_prob(self, coords): try: with warnings.catch_warnings(record=True): warnings.simplefilter( - "error", np.VisibleDeprecationWarning + "error", np.exceptions.VisibleDeprecationWarning ) try: dt = np.atleast_1d(blob[0]).dtype @@ -682,3 +691,16 @@ def ndarray_to_list_of_dicts( list of dictionaries of parameters """ return [{key: xi[val] for key, val in key_map.items()} for xi in x] + + +def _scalar(fx): + # Make sure a value is a true scalar + # 1.0, np.float64(1.0), np.array([1.0]), np.array(1.0) + if not np.isscalar(fx): + try: + fx = np.asarray(fx).item() + except (TypeError, ValueError) as e: + raise ValueError("log_prob_fn should return scalar") from e + return float(fx) + else: + return float(fx) diff --git a/src/emcee/tests/unit/test_ensemble.py b/src/emcee/tests/unit/test_ensemble.py index f0568b27..e7981c89 100644 --- a/src/emcee/tests/unit/test_ensemble.py +++ b/src/emcee/tests/unit/test_ensemble.py @@ -183,3 +183,37 @@ def test_run_mcmc(self): assert results.coords.shape == (n_walkers, len(self.names)) chain = sampler.chain assert chain.shape == (n_walkers, n_steps, len(self.names)) + + +class TestLnProbFn(TestCase): + # checks that the log_prob_fn can deal with a variety of 'scalar-likes' + def lnpdf(self, x): + v = np.log(np.sqrt(np.pi) * np.exp(-((x / 2.0) ** 2))) + v = float(v[0]) + assert np.isscalar(v) + return v + + def lnpdf_arr1(self, x): + v = self.lnpdf(x) + return np.array([v]) + + def lnpdf_float64(self, x): + v = self.lnpdf(x) + return np.float64(v) + + def lnpdf_arr0D(self, x): + v = self.lnpdf(x) + return np.array(v) + + def test_deal_with_scalar_likes(self): + rng = np.random.default_rng() + fns = [ + self.lnpdf, + self.lnpdf_arr1, + self.lnpdf_float64, + self.lnpdf_arr0D, + ] + for fn in fns: + init = rng.random((50, 1)) + sampler = EnsembleSampler(50, 1, fn) + _ = sampler.run_mcmc(initial_state=init, nsteps=20) diff --git a/tox.ini b/tox.ini index d759fa63..c5661748 100644 --- a/tox.ini +++ b/tox.ini @@ -1,12 +1,12 @@ [tox] -envlist = py{37,38,39,310}{,-extras},lint +envlist = py{39,310,311,312}{,-extras},lint [gh-actions] python = - 3.7: py37 - 3.8: py38 - 3.9: py39-extras + 3.9: py39 3.10: py310 + 3.11: py311-extras + 3.12: py312 [testenv] deps = coverage[toml]