Integrate ZarrTrace into pymc.sample
lucianopaz committed Oct 23, 2024
1 parent 247ba88 commit ee0a36d
Showing 2 changed files with 71 additions and 4 deletions.
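For orientation, a minimal usage sketch of what this integration enables; the model, the "trace.zarr" store path, and the ZarrTrace constructor argument are illustrative assumptions, not part of this commit:

import numpy as np
import pymc as pm
from pymc.backends.zarr import ZarrTrace

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.array([0.1, -0.3, 0.2]))

    # Assumption: a ZarrTrace passed as `trace` persists draws, stats,
    # and sampling state to the given zarr store while pm.sample runs.
    idata = pm.sample(draws=500, tune=500, trace=ZarrTrace(store="trace.zarr"))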
15 changes: 14 additions & 1 deletion pymc/backends/__init__.py
@@ -72,6 +72,7 @@
 from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
 from pymc.backends.base import BaseTrace, IBaseTrace
 from pymc.backends.ndarray import NDArray
+from pymc.backends.zarr import ZarrTrace
 from pymc.model import Model
 from pymc.step_methods.compound import BlockedStep, CompoundStep

@@ -118,15 +118,27 @@ def _init_trace(

 def init_traces(
     *,
-    backend: TraceOrBackend | None,
+    backend: TraceOrBackend | ZarrTrace | None,
     chains: int,
     expected_length: int,
     step: BlockedStep | CompoundStep,
     initial_point: Mapping[str, np.ndarray],
     model: Model,
     trace_vars: list[TensorVariable] | None = None,
+    tune: int = 0,
 ) -> tuple[RunType | None, Sequence[IBaseTrace]]:
     """Initialize a trace recorder for each chain."""
+    if isinstance(backend, ZarrTrace):
+        backend.init_trace(
+            chains=chains,
+            draws=expected_length - tune,
+            tune=tune,
+            step=step,
+            model=model,
+            vars=trace_vars,
+            test_point=initial_point,
+        )
+        return None, backend.straces
     if HAS_MCB and isinstance(backend, Backend):
         return init_chain_adapters(
             backend=backend,
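Note the accounting in the new ZarrTrace branch: `expected_length` counts tuning plus posterior iterations, so only the difference is requested as posterior capacity. A one-line sketch with hypothetical values:

# expected_length counts tune + draws, so the posterior capacity
# passed to ZarrTrace.init_trace is the remainder (hypothetical numbers).
expected_length, tune = 1500, 1000
draws = expected_length - tune  # 500 posterior draws allocated per chain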
60 changes: 57 additions & 3 deletions pymc/sampling/mcmc.py
@@ -50,6 +50,7 @@
     find_observations,
 )
 from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
+from pymc.backends.zarr import ZarrTrace
 from pymc.blocking import DictToArrayBijection
 from pymc.exceptions import SamplingError
 from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain

@@ -808,6 +809,7 @@ def joined_blas_limiter():
         trace_vars=trace_vars,
         initial_point=ip,
         model=model,
+        tune=tune,
     )

     sample_args = {

@@ -890,7 +892,7 @@ def joined_blas_limiter():
     # into a function to make it easier to test and refactor.
     return _sample_return(
         run=run,
-        traces=traces,
+        traces=trace if isinstance(trace, ZarrTrace) else traces,
         tune=tune,
         t_sampling=t_sampling,
         discard_tuned_samples=discard_tuned_samples,

@@ -905,7 +907,7 @@
 def _sample_return(
     *,
     run: RunType | None,
-    traces: Sequence[IBaseTrace],
+    traces: Sequence[IBaseTrace] | ZarrTrace,
     tune: int,
     t_sampling: float,
     discard_tuned_samples: bool,

@@ -919,13 +921,65 @@ def _sample_return(
     Final step of `pm.sample`.
"""
if isinstance(traces, ZarrTrace):
# Split warmup from posterior samples
traces.split_warmup_groups()

# Set sampling time
traces._sampling_state.sampling_time[:] = t_sampling

# Compute number of actual draws per chain
total_draws_per_chain = traces._sampling_state.draw_idx[:]
n_chains = len(traces.straces)
desired_tune = traces.tuning_steps
desired_draw = len(traces.posterior.draw)
tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune)
draws_per_chain = total_draws_per_chain - tuning_steps_per_chain

total_n_tune = tuning_steps_per_chain.sum()
total_draws = draws_per_chain.sum()

_log.info(
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations '
f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) "
f"took {t_sampling:.0f} seconds."
)

if compute_convergence_checks or return_inferencedata:
idata = traces.to_inferencedata(save_warmup=not discard_tuned_samples)
log_likelihood = idata_kwargs.pop("log_likelihood", False)
if log_likelihood:
from pymc.stats.log_density import compute_log_likelihood

idata = compute_log_likelihood(
idata,
var_names=None if log_likelihood is True else log_likelihood,
extend_inferencedata=True,
model=model,
sample_dims=["chain", "draw"],
progressbar=False,
)

if compute_convergence_checks:
warns = run_convergence_checks(idata, model)
for warn in warns:
traces._sampling_state.global_warnings.append(warn)
log_warnings(warns)

if return_inferencedata:
# By default we drop the "warning" stat which contains `SamplerWarning`
# objects that can not be stored with `.to_netcdf()`.
if not keep_warning_stat:
return drop_warning_stat(idata)
return idata
return traces

# Pick and slice chains to keep the maximum number of samples
if discard_tuned_samples:
traces, length = _choose_chains(traces, tune)
else:
traces, length = _choose_chains(traces, 0)
mtrace = MultiTrace(traces)[:length]

# count the number of tune/draw iterations that happened
# ideally via the "tune" statistic, but not all samplers record it!
if "tune" in mtrace.stat_names:
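The clipping logic in the new `_sample_return` branch accounts for chains that stopped early (e.g. on interruption): `draw_idx` records how many iterations each chain actually completed, and clipping against the desired tune count splits that total into tuning and posterior parts. A worked sketch with hypothetical per-chain counts:

import numpy as np

# Hypothetical progress: chain 0 finished all 1500 iterations,
# chain 1 was interrupted after 800 (still inside tuning).
total_draws_per_chain = np.array([1500, 800])
desired_tune = 1000

tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune)
# -> array([1000,  800]); no chain can report more tune steps than desired
draws_per_chain = total_draws_per_chain - tuning_steps_per_chain
# -> array([ 500,    0]); the remainder counts as posterior draws

print(tuning_steps_per_chain.sum(), draws_per_chain.sum())  # 1800 500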
