Skip to content

Commit

Permalink
Merge pull request #32 from JuBiotech/pymc5
Browse files Browse the repository at this point in the history
Drop PyMC3 compatibility in favor of PyMC v5
  • Loading branch information
michaelosthege authored Dec 19, 2022
2 parents 5fc74b3 + cf1ca51 commit 27ef1ea
Show file tree
Hide file tree
Showing 8 changed files with 386 additions and 214 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
pymc-version: ["without", "pymc>=4.0.0", '"pymc3>=3.11.5" "numpy<1.22"']
python-version: ["3.8", "3.9"]
pymc-version: ["without", "'pymc>=4.2.2,<5'", "'pymc>=5.0.0'"]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@v4.3.1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ jobs:
env:
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v1
uses: actions/setup-python@v4.3.1
with:
python-version: 3.7
python-version: 3.9
- name: Install dependencies
run: |
pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion bletl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
NoMeasurementData,
)

__version__ = "1.1.3"
__version__ = "1.2.0"
218 changes: 121 additions & 97 deletions bletl/growth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,13 @@
import arviz
import calibr8
import numpy
import scipy.stats
from calibr8.utils import pm
import pymc as pm
from packaging import version

# Use the new ConstantData container if available,
# because it gives superior computational performance.
if hasattr(pm, "ConstantData"):
pmData = pm.ConstantData
else:
pmData = pm.Data


try:
import aesara.tensor as at
import pytensor.tensor as pt
except ModuleNotFoundError:
import theano.tensor as at
import aesara.tensor as pt


_log = logging.getLogger(__file__)
Expand Down Expand Up @@ -147,7 +138,7 @@ def sample(self, **kwargs) -> None:
return_inferencedata=True,
target_accept=0.95,
init="adapt_diag",
start=self.theta_map,
initvals=self.theta_map,
tune=500,
draws=500,
)
Expand All @@ -160,12 +151,14 @@ def sample(self, **kwargs) -> None:
def _make_random_walk(
name: str,
*,
init_dist: pt.TensorVariable,
mu: float = 0,
sigma: float,
nu: float = 1,
length: int,
student_t: bool,
initval: numpy.ndarray = None,
dims: typing.Optional[str] = None,
):
"""Create a random walk with either a Normal or Student-t distribution.
Expand All @@ -176,13 +169,12 @@ def _make_random_walk(
----------
name : str
Name of the random walk variable.
init_dist
A random variable to use as the prior for innovations.
mu : float, array-like
Mean of the random walk.
If a vector is passed, only the first element should be nonzero,
otherwise the random walk will drift systematically.
sigma : float, array-like
Standard deviation (Normal) or scale (StudentT) parameter.
A vector may be passed to customize, for example the prior at the start.
nu : float, array-like
Degree of freedom for the StudentT distribution - only used when `student_t == True`.
length : int
Expand All @@ -193,6 +185,8 @@ def _make_random_walk(
initval : numpy.ndarray
Initial values for the RandomWalk variable.
If set, PyMC uses these values as start points for MAP optimization and MCMC sampling.
dims
Optional dims to be forwarded to the `RandomWalk`.
Returns
-------
Expand All @@ -201,43 +195,23 @@ def _make_random_walk(
"""
pmversion = version.parse(pm.__version__)

# Adapt to rename of the testval→initval kwarg
if pmversion <= version.parse("3.11.4"):
initval_kwarg = "testval"
else:
initval_kwarg = "initval"

if pmversion < version.parse("4.0.0b1") and not student_t:
# Use the gaussian random walk distribution directly.
return pm.GaussianRandomWalk(
**{
"name": name,
"mu": mu,
"sigma": sigma,
"shape": (length,),
initval_kwarg: initval,
}
)
else:
# Create the random walk manually.
rv_kwargs = {
"name": f"{name}__diff_",
"mu": mu,
"sigma": sigma,
"shape": (length,),
# Since the initval refers to the random walk, but we're creating it
# using the cumsum of an RV, we need to do numpy.diff to get an initial
# value for the RV from the initial value of the random walk.
initval_kwarg: numpy.diff(initval, prepend=0) if initval is not None else None,
}

if student_t:
rv_cls = pm.StudentT
rv_kwargs["nu"] = nu
else:
rv_cls = pm.Normal
if pmversion < version.parse("4.2.2"):
raise NotImplementedError("PyMC versions <4.2.2 are no longer supported.")

return pm.Deterministic(name, at.cumsum(rv_cls(**rv_kwargs)))
if student_t:
innov_dist = pm.StudentT.dist(mu=mu, sigma=sigma, nu=nu)
else:
innov_dist = pm.Normal.dist(mu=mu, sigma=sigma)

rw = pm.RandomWalk(
name,
init_dist=init_dist,
innovation_dist=innov_dist,
steps=length - 1,
initval=initval,
dims=dims,
)
return rw


def _get_smoothed_mu(
Expand Down Expand Up @@ -348,8 +322,6 @@ def fit_mu_t(
t_switchpoints_known = numpy.sort(list(switchpoints.keys()))
if student_t is None:
student_t = len(switchpoints) == 0
# build a dict of known switchpoint begin cycle indices so they can be ignored in autodetection
c_switchpoints_known = [0]

# Use a smoothed, diff-based growth rate on the backscatter to initialize the optimization.
# These values are still everything but high-quality estimates of the growth rate,
Expand All @@ -361,23 +333,21 @@ def fit_mu_t(
TD = len(t_data)
TS = len(t_segments)

# The mu_prior parameter is used to initialize the random walk at a more realistic growth rate.
# This can become necessary when there was no lag phase.
if mu_prior != 0:
mu_prior = numpy.array([mu_prior] + [0] * (TS - 1))
# Override guess with user-provided mu_prior for nonzero starting points.
mu_guess[mu_prior != 0] = mu_prior[mu_prior != 0]

# build PyMC model
coords = {
"timepoint": numpy.arange(TD),
"segment": numpy.arange(TS),
}
with pm.Model(coords=coords) as pmodel:
pmData("known_switchpoints", t_switchpoints_known)
pmData("t_data", t_data, dims="timepoint")
pmData("t_segments", t_segments, dims="segment")
dt = pmData("dt", numpy.diff(t_data), dims="segment")
pm.ConstantData("known_switchpoints", t_switchpoints_known)
pm.ConstantData("t_data", t_data, dims="timepoint")
pm.ConstantData("t_segments", t_segments, dims="segment")
dt = pm.ConstantData("dt", numpy.diff(t_data), dims="segment")

# The init dist for the random walk is where each segment starts.
# Here we center it on the user-provided mu_prior,
# taking the absolute of it (+0.05 safety margin to avoid 0) as the scale.
init_dist = pm.Normal.dist(mu=mu_prior, sigma=pt.abs(mu_prior) + 0.05)

if len(t_switchpoints_known) > 0:
_log.info(
Expand All @@ -394,7 +364,8 @@ def fit_mu_t(
mu_segments.append(
_make_random_walk(
name,
mu=mu_prior[slc],
init_dist=init_dist,
mu=0,
sigma=drift_scale,
nu=nu,
length=i_len,
Expand All @@ -403,49 +374,51 @@ def fit_mu_t(
)
)
i_from += i_len
# remember the index to ignore it in potential autodetection
c_switchpoints_known.append(i_from)
# the last segment until the end
i_len = len(t[i_from:]) - 1
name = f"mu_phase_{len(mu_segments)}"
slc = slice(i_from, None)
mu_segments.append(
_make_random_walk(
name,
mu_prior[slc],
init_dist=init_dist,
mu=0,
sigma=drift_scale,
nu=nu,
length=i_len,
student_t=student_t,
initval=mu_guess[slc],
)
)
mu_t = pm.Deterministic("mu_t", at.concatenate(mu_segments), dims="segment")
mu_t = pm.Deterministic("mu_t", pt.concatenate(mu_segments), dims="segment")
else:
_log.info(
"Creating model without switchpoints. StudentT=%b", len(t_switchpoints_known), student_t
)
mu_t = _make_random_walk(
"mu_t",
mu=mu_prior,
init_dist=init_dist,
mu=0,
sigma=drift_scale,
nu=nu,
length=TS,
student_t=student_t,
initval=mu_guess,
dims="segment",
)

X0 = pm.LogNormal("X0", mu=numpy.log(x0_prior), sigma=1)
Xt = pm.Deterministic(
"X",
at.concatenate([X0[None], X0 * pm.math.exp(at.extra_ops.cumsum(mu_t * dt))]),
pt.concatenate([X0[None], X0 * pm.math.exp(pt.extra_ops.cumsum(mu_t * dt))]),
dims="timepoint",
)
calibration_model.loglikelihood(
x=Xt,
y=pmData("backscatter", y, dims=("timepoint",)),
y=pm.ConstantData("backscatter", y, dims=("timepoint",)),
replicate_id=replicate_id,
dependent_key=calibration_model.dependent_key,
dims="timepoint",
)

# MAP fit
Expand All @@ -454,31 +427,14 @@ def fit_mu_t(

# with StudentT random walks, switchpoints can be autodetected
if student_t:
# first CDF values at all mu_t elements
cdf_evals = []
for rvname in sorted(theta_map.keys()):
if "__diff_" in rvname:
rv = pmodel[rvname]
# for every µ, find out where it lies in the CDF of the StudentT prior distribution
cdf_evals += list(
scipy.stats.t.cdf(
x=theta_map[rvname],
loc=rv.owner.inputs[3].eval(),
scale=rv.owner.inputs[4].eval(),
df=rv.owner.inputs[2].eval(),
)
)
cdf_evals = numpy.array(cdf_evals)
# filter for the elements that lie outside of the [0.005, 0.995] interval
significance_mask = numpy.logical_or(
cdf_evals < (switchpoint_prob / 2),
cdf_evals > (1 - switchpoint_prob / 2),
switchpoints_detected = detect_switchpoints(
switchpoint_prob,
t_data,
pmodel,
theta_map,
)
# add these autodetected timepoints to the switchpoints-dict
# (ignore the first timepoint)
for c_switch, (t_switch, is_switchpoint) in enumerate(zip(t_data, significance_mask[1:])):
if is_switchpoint and c_switch not in c_switchpoints_known:
switchpoints[t_switch] = "detected"
# Known switchpoints override detected ones 👇
switchpoints = {**switchpoints_detected, **switchpoints}

# bundle up all relevant variables into a result object
result = GrowthRateResult(
Expand All @@ -500,3 +456,71 @@ def fit_mu_t(
result.sample(draws=mcmc_samples)

return result


def detect_switchpoints(
switchpoint_prob: float,
t_data: typing.Sequence[float],
pmodel: pm.Model,
theta_map: typing.Dict[str, numpy.ndarray],
) -> typing.Dict[float, str]:
"""Helper function to detect switchpoints from a fitted random walk.
Parameters
----------
switchpoint_prob
Probability threshold for detecting switchpoints.
Random walk innovations with a prior probability less than this
will be classified as switchpoints.
t_data
Time values corresponding to the random walk steps.
pmodel
The PyMC model containing `"mu_t*"` random walks.
theta_map
MAP estimate of the model.
Returns
-------
switchpoints
Dictionary of switchpoints with
keys being the time point and
values `"detected"`.
"""
# first CDF values at all mu_t elements
cdf_evals = []
for rvname in sorted(theta_map.keys()):
if rvname not in pmodel.named_vars:
continue
# The random walk may be split in multiple segments.
# We can identify a segment from the RVOp type that created it.
rv = pmodel[rvname]
if rv.owner is None:
continue
if isinstance(rv.owner.op, pm.RandomWalk.rv_type):
# Get a handle on the innovation dist so we can evaluate prior CDFs.
innov_dist = rv.owner.inputs[1]
# Calculate the innovations from the MAP estimate of the points.
# This gives only the deltas between the points, so the 0th element
# in the new vector corresponds to the segment between the 0st and 1nd point.
innov = numpy.diff(theta_map[rvname])
# Now we can evaluate the CDFs of the innovations.
logcdfs = pm.logcdf(innov_dist, innov).eval()
# We define switchpoints based on the time of the point with an extreme CDF value.
# To get our <number of segments> length vector to align with the <number of points>,
# we prepend a 0.5 as a placeholder for the CDF of the initial point of the random walk.
cdf_evals += [0.5, *numpy.exp(logcdfs)]
cdf_evals = numpy.array(cdf_evals)
if len(cdf_evals) != len(t_data) - 1:
raise Exception(
f"Failed to find all random walk segments. Found {len(cdf_evals)}, expected {len(t_data) - 1}."
)
# Filter for the elements that lie outside of the [0.005, 0.995] interval (if switchpoint_prob=0.01).
significance_mask = numpy.logical_or(
cdf_evals < (switchpoint_prob / 2),
cdf_evals > (1 - switchpoint_prob / 2),
)
# Collect switchpoint information from points with significant CDF values.
# Here we don't need to filter known switchpoints, because these correspond to the first
# point in each random walk, for which we assigned non-significant 0.5 CDF placeholders above.
switchpoints = {t: "detected" for t, is_switchpoint in zip(t_data, significance_mask) if is_switchpoint}
return switchpoints
Loading

0 comments on commit 27ef1ea

Please sign in to comment.