-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement uncensor and forecast_timeseries model transformation
- Loading branch information
1 parent
05c9c24
commit 15b78db
Showing
2 changed files
with
265 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
from pymc import DiracDelta | ||
from pymc.distributions.censored import CensoredRV | ||
from pymc.distributions.timeseries import AR, AutoRegressiveRV | ||
from pymc.model import Model | ||
from pytensor import shared | ||
from pytensor.graph import FunctionGraph, node_rewriter | ||
from pytensor.graph.basic import get_var_by_name | ||
from pytensor.graph.rewriting.basic import in2out | ||
|
||
from pymc_experimental.utils.model_fgraph import ( | ||
ModelObservedRV, | ||
ModelValuedVar, | ||
fgraph_from_model, | ||
model_free_rv, | ||
model_from_fgraph, | ||
model_named, | ||
) | ||
|
||
__all__ = ( | ||
"forecast_timeseries", | ||
"uncensor", | ||
) | ||
|
||
|
||
@node_rewriter(tracks=[ModelValuedVar]) | ||
def uncensor_node_rewrite(fgraph, node): | ||
"""Rewrite that replaces censored variables by uncensored ones""" | ||
|
||
( | ||
censored_rv, | ||
value, | ||
*dims, | ||
) = node.inputs | ||
if not isinstance(censored_rv.owner.op, CensoredRV): | ||
return | ||
|
||
model_rv = node.outputs[0] | ||
base_rv = censored_rv.owner.inputs[0] | ||
uncensored_rv = node.op.make_node(base_rv, value, *dims).default_output() | ||
uncensored_rv.name = f"{model_rv.name}_uncensored" | ||
return [uncensored_rv] | ||
|
||
|
||
uncensor_rewrite = in2out(uncensor_node_rewrite) | ||
|
||
|
||
def uncensor(model: Model) -> Model: | ||
"""Replace censored variables in the model by uncensored equivalent. | ||
Replaced variables have the same name as original ones with an additional "_uncensored" suffix. | ||
.. code-block:: python | ||
import arviz as az | ||
import pymc as pm | ||
from pymc_experimental.model_transform import uncensor | ||
with pm.Model() as model: | ||
x = pm.Normal("x") | ||
dist_raw = pm.Normal.dist(x) | ||
y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[-1, 0.5, 1, 1, 1]) | ||
idata = pm.sample() | ||
with uncensor(model): | ||
idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_uncensored"]) | ||
az.summary(idata_pp) | ||
""" | ||
fg = fgraph_from_model(model) | ||
|
||
(_, nodes_changed, *_) = uncensor_rewrite.apply(fg) | ||
if not nodes_changed: | ||
raise RuntimeError("No censored variables were replaced by uncensored counterparts") | ||
|
||
return model_from_fgraph(fg) | ||
|
||
|
||
@node_rewriter(tracks=[ModelValuedVar]) | ||
def forecast_timeseries_node_rewrite(fgraph: FunctionGraph, node): | ||
"""Rewrite that replaces censored variables by uncensored ones""" | ||
|
||
( | ||
timeseries_rv, | ||
value, | ||
*dims, | ||
) = node.inputs | ||
if not isinstance(timeseries_rv.owner.op, AutoRegressiveRV): | ||
return | ||
|
||
forecast_steps = get_var_by_name(fgraph.inputs, "forecast_steps_") | ||
if len(forecast_steps) != 1: | ||
return False | ||
|
||
forecast_steps = forecast_steps[0] | ||
|
||
op = timeseries_rv.owner.op | ||
model_rv = node.outputs[0] | ||
|
||
# We cannot reference the variable we are planning to replace | ||
# Or it will introduce circularities in the graph | ||
# FIXME: This special logic shouldn't be needed for ObservedRVs | ||
# but PyMC does not allow one to not resample observed. | ||
# We hack around by conditioning on the value variable directly, | ||
# even though that should not be part of the generative graph... | ||
if isinstance(node.op, ModelObservedRV): | ||
init_dist = DiracDelta.dist(value[..., -op.ar_order :]) | ||
else: | ||
cloned_model_rv = model_rv.owner.clone().default_output() | ||
fgraph.add_output(cloned_model_rv, import_missing=True) | ||
init_dist = DiracDelta.dist(cloned_model_rv[..., -op.ar_order :]) | ||
|
||
if isinstance(timeseries_rv.owner.op, AutoRegressiveRV): | ||
rhos, sigma, *_ = timeseries_rv.owner.inputs | ||
new_timeseries_rv = AR.rv_op( | ||
rhos=rhos, | ||
sigma=sigma, | ||
init_dist=init_dist, | ||
steps=forecast_steps, | ||
ar_order=op.ar_order, | ||
constant_term=op.constant_term, | ||
) | ||
|
||
new_name = f"{model_rv.name}_forecast" | ||
new_value = new_timeseries_rv.type(name=new_name) | ||
new_timeseries_rv = model_free_rv(new_timeseries_rv, new_value, transform=None) | ||
new_timeseries_rv.name = new_name | ||
|
||
# Import new variables into fgraph (value and RNG) | ||
fgraph.import_var(new_timeseries_rv, import_missing=True) | ||
|
||
return [new_timeseries_rv] | ||
|
||
|
||
forecast_timeseries_rewrite = in2out(forecast_timeseries_node_rewrite, ignore_newtrees=True) | ||
|
||
|
||
def forecast_timeseries(model: Model, forecast_steps: int) -> Model: | ||
"""Replace timeseries variables in the model by forecast that start at the last value. | ||
Replaced variables have the same name as original ones with an additional "_forecast" suffix. | ||
The function will fail if any variables with fixed static shape depend on the timeseries being replaced, | ||
and forecast_steps differs from the original timeseries steps. | ||
.. code-block:: python | ||
import pymc as pm | ||
from pymc_experimental.model_transform import forecast_timeseries | ||
with pm.Model() as model: | ||
rho = pm.Normal("rho") | ||
sigma = pm.HalfNormal("sigma") | ||
init_dist = pm.Normal.dist() | ||
y = pm.AR("y", init_dist=init_dist, rho=rho, sigma=sigma, observed=np.zeros(100,)) | ||
idata = pm.sample() | ||
forecast_model = forecast_timeseries(mode, forecast_steps=20) | ||
with forecast_model: | ||
idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_forecast"]) | ||
az.summary(idata_pp) | ||
""" | ||
|
||
fg = fgraph_from_model(model) | ||
|
||
forecast_steps_sh = shared(forecast_steps, name="forecast_steps_") | ||
forecast_steps_sh = model_named(forecast_steps_sh) | ||
fg.add_output(forecast_steps_sh, import_missing=True) | ||
|
||
(_, nodes_changed, *_) = forecast_timeseries_rewrite.apply(fg) | ||
if not nodes_changed: | ||
raise RuntimeError("No timeseries were replaced by forecast counterparts") | ||
|
||
res = model_from_fgraph(fg) | ||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import arviz as az | ||
import numpy as np | ||
import pymc as pm | ||
import pytest | ||
|
||
from pymc_experimental.model_transform import forecast_timeseries, uncensor | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"transform, kwargs", | ||
[ | ||
(uncensor, dict()), | ||
(forecast_timeseries, dict(forecast_steps=20)), | ||
], | ||
) | ||
def test_transform_error(transform, kwargs): | ||
"""Test informative error is raised when the transform is not applicable to a model.""" | ||
with pm.Model() as model: | ||
x = pm.Normal("x") | ||
y = pm.Normal("y", x, observed=[0, 5, 10]) | ||
|
||
with pytest.raises(RuntimeError, match="No .* were replaced by .* counterparts"): | ||
transform(model, **kwargs) | ||
|
||
|
||
def test_uncensor(): | ||
with pm.Model() as model: | ||
x = pm.Normal("x") | ||
dist_raw = pm.Normal.dist(x) | ||
y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[0, 5, 10]) | ||
det = pm.Deterministic("det", y * 2) | ||
|
||
idata = az.from_dict({"x": np.zeros((2, 500))}) | ||
|
||
with uncensor(model): | ||
pp = pm.sample_posterior_predictive( | ||
idata, | ||
var_names=["y_uncensored", "det"], | ||
random_seed=18, | ||
).posterior_predictive | ||
|
||
assert np.any(np.abs(pp["y_uncensored"]) > 1) | ||
np.testing.assert_allclose(pp["y_uncensored"] * 2, pp["det"]) | ||
|
||
|
||
@pytest.mark.parametrize("observed", (True, False)) | ||
@pytest.mark.parametrize("ar_order", (1, 2)) | ||
def test_forecast_timeseries_ar(observed, ar_order): | ||
data_steps = 3 | ||
data = np.hstack((np.zeros(ar_order), (np.arange(data_steps) + 1) * 100.0)) | ||
with pm.Model() as model: | ||
rho = pm.Normal("rho", shape=(ar_order,)) | ||
sigma = pm.HalfNormal("sigma") | ||
init_dist = pm.Normal.dist(0, 1e-3) | ||
y = pm.AR( | ||
"y", | ||
init_dist=init_dist, | ||
rho=rho, | ||
sigma=sigma, | ||
observed=data if observed else None, | ||
steps=data_steps, | ||
) | ||
det = pm.Deterministic("det", y * 2) | ||
|
||
draws = (2, 50) | ||
# These rhos mean that all steps will be data[-1] for ar_order > 1 | ||
idata_dict = { | ||
"rho": np.full((*draws, ar_order), (0.1,) + (0,) * (ar_order - 1)), | ||
"sigma": np.full(draws, 1e-5), | ||
} | ||
if observed: | ||
idata = az.from_dict(idata_dict, observed_data={"y": data}) | ||
else: | ||
idata_dict["y"] = np.full((*draws, len(data)), data) | ||
idata = az.from_dict(idata_dict) | ||
|
||
forecast_steps = 5 | ||
with forecast_timeseries(model, forecast_steps=forecast_steps): | ||
pp = pm.sample_posterior_predictive( | ||
idata, | ||
var_names=["y_forecast", "det"], | ||
random_seed=50, | ||
).posterior_predictive | ||
|
||
expected = data[-1] / np.logspace(0, forecast_steps, forecast_steps + 1) | ||
expected = np.hstack((data[-ar_order:-1], expected)) | ||
np.testing.assert_allclose( | ||
pp["y_forecast"].values, np.full((*draws, forecast_steps + ar_order), expected), rtol=0.01 | ||
) | ||
np.testing.assert_allclose(pp["y_forecast"] * 2, pp["det"]) |