diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f02afdc623..764c44a2d2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -108,6 +108,7 @@ jobs: - | tests/backends/test_mcbackend.py + tests/backends/test_zarr.py tests/distributions/test_truncated.py tests/logprob/test_abstract.py tests/logprob/test_basic.py @@ -284,6 +285,7 @@ jobs: - | tests/backends/test_arviz.py + tests/backends/test_zarr.py tests/variational/test_updates.py fail-fast: false runs-on: ${{ matrix.os }} diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py new file mode 100644 index 0000000000..78172d3c88 --- /dev/null +++ b/pymc/backends/zarr.py @@ -0,0 +1,540 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any + +import arviz as az +import numcodecs +import numpy as np +import xarray as xr +import zarr + +from arviz.data.base import make_attrs +from arviz.data.inference_data import WARMUP_TAG +from numcodecs.abc import Codec +from pytensor.tensor.variable import TensorVariable +from zarr.storage import BaseStore, default_compressor +from zarr.sync import Synchronizer + +import pymc + +from pymc.backends.arviz import ( + coords_and_dims_for_inferencedata, + find_constants, + find_observations, +) +from pymc.backends.base import BaseTrace +from pymc.blocking import StatDtype, StatShape +from pymc.model.core import Model, modelcontext +from pymc.step_methods.compound import ( + BlockedStep, + CompoundStep, + StatsBijection, + get_stats_dtypes_shapes_from_steps, +) +from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name + + +class ZarrChain(BaseTrace): + def __init__( + self, + store: BaseStore | MutableMapping, + stats_bijection: StatsBijection, + synchronizer: Synchronizer | None = None, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + test_point: Sequence[dict[str, np.ndarray]] | None = None, + ): + super().__init__(name="zarr", model=model, vars=vars, test_point=test_point) + self.unconstrained_variables = { + var.name for var in self.vars if is_transformed_name(var.name) + } + self.draw_idx = 0 + self._posterior = zarr.open_group( + store, synchronizer=synchronizer, path="posterior", mode="a" + ) + if self.unconstrained_variables: + self._unconstrained_posterior = zarr.open_group( + store, synchronizer=synchronizer, path="unconstrained_posterior", mode="a" + ) + self._sample_stats = zarr.open_group( + store, synchronizer=synchronizer, path="sample_stats", mode="a" + ) + self._sampling_state = zarr.open_group( + store, synchronizer=synchronizer, path="_sampling_state", mode="a" + ) + self.stats_bijection = stats_bijection + + 
def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override] + self.chain = chain + + def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]): + chain = self.chain + draw_idx = self.draw_idx + unconstrained_variables = self.unconstrained_variables + for var_name, var_value in zip(self.varnames, self.fn(draw)): + if var_name in unconstrained_variables: + self._unconstrained_posterior[var_name].set_orthogonal_selection( + (chain, draw_idx), + var_value, + ) + else: + self._posterior[var_name].set_orthogonal_selection( + (chain, draw_idx), + var_value, + ) + for var_name, var_value in self.stats_bijection.map(stats).items(): + self._sample_stats[var_name].set_orthogonal_selection( + (chain, draw_idx), + var_value, + ) + self.draw_idx += 1 + + def record_sampling_state(self, step): + self._sampling_state.sampling_state.set_coordinate_selection( + self.chain, np.array([step.sampling_state], dtype="object") + ) + self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx) + + +FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None +DEFAULT_FILL_VALUES: dict[Any, FILL_VALUE_TYPE] = { + np.floating: np.nan, + np.integer: 0, + np.bool_: False, + np.str_: "", + np.datetime64: np.datetime64(0, "Y"), + np.timedelta64: np.timedelta64(0, "Y"), +} + + +def get_initial_fill_value_and_codec( + dtype: Any, +) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]: + _dtype = np.dtype(dtype) + fill_value: FILL_VALUE_TYPE = None + codec = None + try: + fill_value = DEFAULT_FILL_VALUES[_dtype] + except KeyError: + for key in DEFAULT_FILL_VALUES: + if np.issubdtype(_dtype, key): + fill_value = DEFAULT_FILL_VALUES[key] + break + else: + codec = numcodecs.Pickle() + return fill_value, _dtype, codec + + +class ZarrTrace: + def __init__( + self, + store: BaseStore | MutableMapping | None = None, + synchronizer: Synchronizer | None = None, + compressor: 
Codec | None | _UnsetType = UNSET, + draws_per_chunk: int = 1, + include_transformed: bool = False, + ): + self.synchronizer = synchronizer + if compressor is UNSET: + compressor = default_compressor + self.compressor = compressor + self.root = zarr.group( + store=store, + overwrite=True, + synchronizer=synchronizer, + ) + + self.draws_per_chunk = int(draws_per_chunk) + assert self.draws_per_chunk >= 1 + + self.include_transformed = include_transformed + + self._is_base_setup = False + + @property + def posterior(self) -> zarr.Group: + return self.root.posterior + + @property + def unconstrained_posterior(self) -> zarr.Group: + return self.root.unconstrained_posterior + + @property + def sample_stats(self) -> zarr.Group: + return self.root.sample_stats + + @property + def constant_data(self) -> zarr.Group: + return self.root.constant_data + + @property + def observed_data(self) -> zarr.Group: + return self.root.observed_data + + @property + def _sampling_state(self) -> zarr.Group: + return self.root._sampling_state + + def init_trace( + self, + chains: int, + draws: int, + tune: int, + step: BlockedStep | CompoundStep, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + test_point: Mapping[str, np.ndarray] | None = None, + ): + if self._is_base_setup: + raise RuntimeError("The ZarrTrace has already been initialized") # pragma: no cover + model = modelcontext(model) + self.model = model + self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model) + if vars is None: + vars = model.unobserved_value_vars + + unnamed_vars = {var for var in vars if var.name is None} + assert not unnamed_vars, f"Can't trace unnamed variables: {unnamed_vars}" + self.varnames = get_default_varnames( + [var.name for var in vars], include_transformed=self.include_transformed + ) + self.vars = [var for var in vars if var.name in self.varnames] + + self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore") + + # Get 
variable shapes. Most backends will need this + # information. + if test_point is None: + test_point = model.initial_point() + var_values = list(zip(self.varnames, self.fn(test_point))) + self.var_dtype_shapes = { + var: (value.dtype, value.shape) + for var, value in var_values + if not is_transformed_name(var) + } + self.unc_var_dtype_shapes = { + var: (value.dtype, value.shape) for var, value in var_values if is_transformed_name(var) + } + + self.create_group( + name="constant_data", + data_dict=find_constants(self.model), + ) + + self.create_group( + name="observed_data", + data_dict=find_observations(self.model), + ) + + # Create the posterior that includes warmup draws + self.init_group_with_empty( + group=self.root.create_group(name="posterior", overwrite=True), + var_dtype_and_shape=self.var_dtype_shapes, + chains=chains, + draws=tune + draws, + ) + + # Create the unconstrained posterior group that includes warmup draws + if self.include_transformed and self.unc_var_dtype_shapes: + self.init_group_with_empty( + group=self.root.create_group(name="unconstrained_posterior", overwrite=True), + var_dtype_and_shape=self.unc_var_dtype_shapes, + chains=chains, + draws=tune + draws, + ) + + # Create the sample stats that include warmup draws + stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps( + [step] if isinstance(step, BlockedStep) else step.methods + ) + self.init_group_with_empty( + group=self.root.create_group(name="sample_stats", overwrite=True), + var_dtype_and_shape=stats_dtypes_shapes, + chains=chains, + draws=tune + draws, + ) + + self.init_sampling_state_group(tune=tune, chains=chains) + + # One ZarrChain per requested chain; a single shared store is sliced by chain index. + self.straces = [ + ZarrChain( + store=self.root.store, + synchronizer=self.synchronizer, + model=self.model, + vars=self.vars, + test_point=test_point, + stats_bijection=StatsBijection(step.stats_dtypes), + ) + for _ in range(chains) + ] + for chain, strace in enumerate(self.straces): + strace.setup(draws=draws, chain=chain, sampler_vars=None) + + def split_warmup_groups(self): + 
self.split_warmup("posterior", error_if_already_split=False) + self.split_warmup("sample_stats", error_if_already_split=False) + try: + self.split_warmup("unconstrained_posterior", error_if_already_split=False) + except KeyError: + pass + + def consolidate(self): + if not isinstance(self.root.store, zarr.storage.ConsolidatedMetadataStore): + self.split_warmup_groups() + self.root = zarr.consolidate_metadata(self.root.store) + + @property + def tuning_steps(self): + try: + return int(self._sampling_state.tuning_steps.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no tuning step information available" + ) + + @property + def sampling_time(self): + try: + return float(self._sampling_state.sampling_time.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no sampling time information available" + ) + + def init_sampling_state_group(self, tune: int, chains: int): + state = self.root.create_group(name="_sampling_state", overwrite=True) + sampling_state = state.empty( + name="sampling_state", + overwrite=True, + shape=(chains,), + chunks=(1,), + dtype="object", + object_codec=numcodecs.Pickle(), + compressor=self.compressor, + ) + sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + draw_idx = state.array( + name="draw_idx", + overwrite=True, + data=np.zeros(chains, dtype="int"), + chunks=(1,), + dtype="int", + fill_value=-1, + compressor=self.compressor, + ) + draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + + state.array( + name="tuning_steps", + data=tune, + overwrite=True, + dtype="int", + fill_value=0, + compressor=self.compressor, + ) + state.array( + name="sampling_time", + data=0.0, + dtype="float", + fill_value=0.0, + compressor=self.compressor, + ) + state.array( + name="sampling_start_time", + data=0.0, + dtype="float", + fill_value=0.0, + compressor=self.compressor, + 
) + + chain = state.array( + name="chain", + data=np.arange(chains), + compressor=self.compressor, + ) + + chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + + state.empty( + name="global_warnings", + dtype="object", + object_codec=numcodecs.Pickle(), + fill_value=None, + shape=(0,), + ) + + def init_group_with_empty( + self, + group: zarr.Group, + var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]], + chains: int, + draws: int, + ) -> zarr.Group: + group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)} + for name, (_dtype, shape) in var_dtype_and_shape.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype) + shape = shape or () + array = group.full( + name=name, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + shape=(chains, draws, *shape), + chunks=(1, self.draws_per_chunk, *shape), + compressor=self.compressor, + ) + try: + dims = self.vars_to_dims[name] + for dim in dims: + group_coords[dim] = self.coords[dim] + except KeyError: + dims = [] + for i, shape_i in enumerate(shape): + dim = f"{name}_dim_{i}" + dims.append(dim) + group_coords[dim] = np.arange(shape_i, dtype="int") + dims = ("chain", "draw", *dims) + array.attrs.update({"_ARRAY_DIMENSIONS": dims}) + for dim, coord in group_coords.items(): + array = group.array( + name=dim, + data=coord, + fill_value=None, + compressor=self.compressor, + ) + array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) + return group + + def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None: + group: zarr.Group | None = None + if data_dict: + group_coords = {} + group = self.root.create_group(name=name, overwrite=True) + for var_name, var_value in data_dict.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(var_value.dtype) + array = group.array( + name=var_name, + data=var_value, + fill_value=fill_value, + dtype=dtype, + object_codec=object_codec, + compressor=self.compressor, + ) + 
try: + dims = self.vars_to_dims[var_name] + for dim in dims: + group_coords[dim] = self.coords[dim] + except KeyError: + dims = [] + for i in range(var_value.ndim): + dim = f"{var_name}_dim_{i}" + dims.append(dim) + group_coords[dim] = np.arange(var_value.shape[i], dtype="int") + array.attrs.update({"_ARRAY_DIMENSIONS": dims}) + for dim, coord in group_coords.items(): + array = group.array( + name=dim, + data=coord, + fill_value=None, + compressor=self.compressor, + ) + array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) + return group + + def split_warmup(self, group_name, error_if_already_split=True): + if error_if_already_split and f"{WARMUP_TAG}{group_name}" in { + group_name for group_name, _ in self.root.groups() + }: + raise RuntimeError(f"Warmup data for {group_name} has already been split") + posterior_group = self.root[group_name] + tune = self.tuning_steps + warmup_group = self.root.create_group(f"{WARMUP_TAG}{group_name}", overwrite=True) + if tune == 0: + try: + self.root.pop(f"{WARMUP_TAG}{group_name}") + except KeyError: + pass + return + for name, array in posterior_group.arrays(): + array_attrs = array.attrs.asdict() + if name == "draw": + warmup_array = warmup_group.array( + name="draw", + data=np.arange(tune), + dtype="int", + compressor=self.compressor, + ) + posterior_array = posterior_group.array( + name=name, + data=np.arange(len(array) - tune), + dtype="int", + overwrite=True, + compressor=self.compressor, + ) + posterior_array.attrs.update(array_attrs) + else: + dims = array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[:2] == ["chain", "draw"]: + must_overwrite_posterior = True + warmup_idx = (slice(None), slice(None, tune, None)) + posterior_idx = (slice(None), slice(tune, None, None)) + else: + must_overwrite_posterior = False + warmup_idx = slice(None) + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(array.dtype) + warmup_array = warmup_group.array( + name=name, + data=array[warmup_idx], + chunks=array.chunks, 
+ dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + compressor=self.compressor, + ) + if must_overwrite_posterior: + posterior_array = posterior_group.array( + name=name, + data=array[posterior_idx], + chunks=array.chunks, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + overwrite=True, + compressor=self.compressor, + ) + posterior_array.attrs.update(array_attrs) + warmup_array.attrs.update(array_attrs) + + def to_inferencedata(self, save_warmup=False) -> az.InferenceData: + self.consolidate() + # The ConsolidatedMetadataStore looks like an empty store from xarray's point of view + # we need to actually grab the underlying store so that xarray doesn't produce completely + # empty arrays + store = self.root.store.store + groups = {} + try: + global_attrs = { + "tuning_steps": self.tuning_steps, + "sampling_time": self.sampling_time, + } + except AttributeError: + global_attrs = {} # pragma: no cover + for name, _ in self.root.groups(): + if name.startswith("_") or (not save_warmup and name.startswith(WARMUP_TAG)): + continue + data = xr.open_zarr(store, group=name, mask_and_scale=False) + attrs = {**data.attrs, **global_attrs} + data.attrs = make_attrs(attrs=attrs, library=pymc) + groups[name] = data.load() if az.rcParams["data.load"] == "eager" else data + return az.InferenceData(**groups) diff --git a/requirements.txt b/requirements.txt index b59ca29127..0a2c224706 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ rich>=13.7.1 scipy>=1.4.1 threadpoolctl>=3.1.0,<4.0.0 typing-extensions>=3.7.4 +zarr>=2.5.0,<3 diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py new file mode 100644 index 0000000000..34a4bfffb1 --- /dev/null +++ b/tests/backends/test_zarr.py @@ -0,0 +1,340 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools + +from dataclasses import asdict + +import numpy as np +import pytest +import zarr + +import pymc as pm + +from pymc.backends.zarr import ZarrTrace +from pymc.stats.convergence import SamplerWarning +from pymc.step_methods import NUTS, CompoundStep, Metropolis +from pymc.step_methods.state import equal_dataclass_values +from tests.helpers import equal_sampling_states + + +@pytest.fixture(scope="module") +def model(): + time_int = np.array([np.timedelta64(np.timedelta64(i, "h"), "ns") for i in range(25)]) + coords = { + "dim_int": range(3), + "dim_str": ["A", "B"], + "dim_time": np.datetime64("2024-10-16") + time_int, + "dim_interval": time_int, + } + rng = np.random.default_rng(42) + with pm.Model(coords=coords) as model: + data1 = pm.Data("data1", np.ones(3, dtype="bool"), dims=["dim_int"]) + data2 = pm.Data("data2", np.ones(3, dtype="bool")) + time = pm.Data("time", time_int / np.timedelta64(1, "h"), dims="dim_time") + + a = pm.Normal("a", shape=(len(coords["dim_int"]), len(coords["dim_str"]))) + b = pm.Normal("b", dims=["dim_int", "dim_str"]) + c = pm.Deterministic("c", a + b, dims=["dim_int", "dim_str"]) + + d = pm.LogNormal("d", dims="dim_time") + e = pm.Deterministic("e", (time + d)[:, None] + c[0], dims=["dim_interval", "dim_str"]) + + obs = pm.Normal( + "obs", + mu=e, + observed=rng.normal(size=(len(coords["dim_time"]), len(coords["dim_str"]))), + dims=["dim_time", "dim_str"], + ) + + return model + + +@pytest.fixture(params=[True, False]) +def include_transformed(request): + return request.param + + 
+@pytest.fixture(params=["single_step", "compound_step"]) +def model_step(request, model): + rng = np.random.default_rng(42) + with model: + if request.param == "single_step": + step = NUTS(rng=rng) + else: + rngs = rng.spawn(2) + step = CompoundStep( + [ + Metropolis(vars=model["a"], rng=rngs[0]), + NUTS(vars=[rv for rv in model.value_vars if rv.name != "a"], rng=rngs[1]), + ] + ) + return step + + +def test_record(model, model_step, include_transformed): + store = zarr.MemoryStore() + trace = ZarrTrace(store=store, include_transformed=include_transformed) + draws = 5 + tune = 5 + trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) + + # Assert that init was successful + expected_groups = { + "_sampling_state", + "sample_stats", + "posterior", + "constant_data", + "observed_data", + } + if include_transformed: + expected_groups.add("unconstrained_posterior") + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + + # Record samples from the ZarrChain + manually_collected_warmup_draws = [] + manually_collected_warmup_stats = [] + manually_collected_draws = [] + manually_collected_stats = [] + point = model.initial_point() + for draw in range(tune + draws): + tuning = draw < tune + if not tuning: + model_step.stop_tuning() + point, stats = model_step.step(point) + if tuning: + manually_collected_warmup_draws.append(point) + manually_collected_warmup_stats.append(stats) + else: + manually_collected_draws.append(point) + manually_collected_stats.append(stats) + trace.straces[0].record(point, stats) + trace.straces[0].record_sampling_state(model_step) + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + + # Assert split warmup + trace.split_warmup("posterior") + trace.split_warmup("sample_stats") + expected_groups = { + "_sampling_state", + "sample_stats", + "posterior", + "warmup_sample_stats", + "warmup_posterior", + "constant_data", + "observed_data", + } + if 
include_transformed: + trace.split_warmup("unconstrained_posterior") + expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + # trace.consolidate() + + # Assert observed data is correct + assert set(dict(trace.observed_data.arrays())) == {"obs", "dim_time", "dim_str"} + assert list(trace.observed_data.obs.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time", "dim_str"] + np.testing.assert_array_equal(trace.observed_data.dim_time[:], model.coords["dim_time"]) + np.testing.assert_array_equal(trace.observed_data.dim_str[:], model.coords["dim_str"]) + + # Assert constant data is correct + assert set(dict(trace.constant_data.arrays())) == { + "data1", + "data2", + "data2_dim_0", + "time", + "dim_time", + "dim_int", + } + assert list(trace.constant_data.data1.attrs["_ARRAY_DIMENSIONS"]) == ["dim_int"] + assert list(trace.constant_data.data2.attrs["_ARRAY_DIMENSIONS"]) == ["data2_dim_0"] + assert list(trace.constant_data.time.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time"] + np.testing.assert_array_equal(trace.constant_data.dim_time[:], model.coords["dim_time"]) + np.testing.assert_array_equal(trace.constant_data.dim_int[:], model.coords["dim_int"]) + + # Assert unconstrained posterior has correct shapes + assert {rv.name for rv in model.free_RVs + model.deterministics} <= set( + dict(trace.posterior.arrays()) + ) + if include_transformed: + assert {"d_log__", "chain", "draw", "d_log___dim_0"} == set( + dict(trace.unconstrained_posterior.arrays()) + ) + assert list(trace.unconstrained_posterior.d_log__.attrs["_ARRAY_DIMENSIONS"]) == [ + "chain", + "draw", + "d_log___dim_0", + ] + np.testing.assert_array_equal(trace.unconstrained_posterior.chain, np.arange(1)) + np.testing.assert_array_equal(trace.unconstrained_posterior.draw, np.arange(draws)) + np.testing.assert_array_equal( + trace.unconstrained_posterior.d_log___dim_0, 
np.arange(len(model.coords["dim_time"])) + ) + + # Assert posterior has correct shape + posterior_dims = set() + for rv_name in [rv.name for rv in model.free_RVs + model.deterministics]: + if rv_name == "a": + expected_dims = ["a_dim_0", "a_dim_1"] + else: + expected_dims = model.named_vars_to_dims[rv_name] + posterior_dims |= set(expected_dims) + assert list(trace.posterior[rv_name].attrs["_ARRAY_DIMENSIONS"]) == [ + "chain", + "draw", + *expected_dims, + ] + for posterior_dim in posterior_dims: + try: + model_coord = model.coords[posterior_dim] + except KeyError: + model_coord = { + "a_dim_0": np.arange(len(model.coords["dim_int"])), + "a_dim_1": np.arange(len(model.coords["dim_str"])), + "chain": np.arange(1), + "draw": np.arange(draws), + }[posterior_dim] + np.testing.assert_array_equal(trace.posterior[posterior_dim][:], model_coord) + + # Assert sample stats have correct shape + stats_bijection = trace.straces[0].stats_bijection + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_draws, manually_collected_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var in trace.posterior.arrays(): + assert np.array_equal(trace.posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{stat_val} != {value}") + + # Assert manually collected warmup samples match + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_warmup_draws, manually_collected_warmup_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var == "d_log__": + if not include_transformed: + continue + posterior = 
trace.root["warmup_unconstrained_posterior"] + else: + posterior = trace.root["warmup_posterior"] + if var in posterior.arrays(): + assert np.array_equal(posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["warmup_sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{stat_val} != {value}") + + # Assert manually collected posterior samples match + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_draws, manually_collected_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var == "d_log__": + if not include_transformed: + continue + posterior = trace.root["unconstrained_posterior"] + else: + posterior = trace.root["posterior"] + if var in posterior.arrays(): + assert np.array_equal(posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{stat_val} != {value}") + + # Assert sampling_state is correct + assert list(trace._sampling_state.draw_idx[:]) == [draws + tune] + assert equal_sampling_states( + trace._sampling_state.sampling_state[0], + model_step.sampling_state, + ) + + # Assert to inference data returns the expected groups + idata = trace.to_inferencedata() + expected_groups = { + "posterior", + "constant_data", + "observed_data", + "sample_stats", + "warmup_posterior", + "warmup_sample_stats", + } + if include_transformed: + 
expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") + assert set(idata.groups()) == expected_groups + for group in idata.groups(): + for name, value in itertools.chain( + idata[group].data_vars.items(), idata[group].coords.items() + ): + try: + array = getattr(trace, group)[name][:] + except AttributeError: + array = trace.root[group][name][:] + if "sample_stats" in group and "warning" in name: + continue + np.testing.assert_array_equal(array, value) + + +@pytest.mark.parametrize("tune", [0, 5, 10]) +def test_split_warmup(tune, model, model_step, include_transformed): + store = zarr.MemoryStore() + trace = ZarrTrace(store=store, include_transformed=include_transformed) + draws = 10 - tune + trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) + + trace.split_warmup("posterior") + trace.split_warmup("sample_stats") + assert len(trace.root.posterior.draw) == draws + assert len(trace.root.sample_stats.draw) == draws + if tune == 0: + with pytest.raises(KeyError): + trace.root["warmup_posterior"] + else: + assert len(trace.root["warmup_posterior"].draw) == tune + assert len(trace.root["warmup_sample_stats"].draw) == tune + + with pytest.raises(RuntimeError): + trace.split_warmup("posterior") + + for var_name, posterior_array in trace.posterior.arrays(): + dims = posterior_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert posterior_array.shape[1] == draws + assert trace.root["warmup_posterior"][var_name].shape[1] == tune + for var_name, sample_stats_array in trace.sample_stats.arrays(): + dims = sample_stats_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert sample_stats_array.shape[1] == draws + assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune