From 150fb0fba14f53dec651ae3a9f31d346bd68d69b Mon Sep 17 00:00:00 2001 From: "hasnain3257@gmail.com" Date: Fri, 24 Nov 2023 12:33:35 +0100 Subject: [PATCH] Bump PyMC minimum version requirement --- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- pymc_experimental/model/marginal_model.py | 2 -- pymc_experimental/utils/prior.py | 14 +++++++------- requirements.txt | 2 +- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index e3ea81f8..5ca1135b 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -11,6 +11,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.9.0 # CI was failing to resolve + - pymc>=5.10.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 4fef73cf..6194fc8e 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -10,6 +10,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.9.0 # CI was failing to resolve + - pymc>=5.10.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py index 77ee14ab..89796c0c 100644 --- a/pymc_experimental/model/marginal_model.py +++ b/pymc_experimental/model/marginal_model.py @@ -461,8 +461,6 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): for i in range(len(marginalized_rv_domain)) ] else: - # Make sure this rewrite is registered - from pymc.pytensorf import local_remove_check_parameter def logp_fn(marginalized_rv_const, *non_sequences): return joint_logp_op(marginalized_rv_const, *non_sequences) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 962b01bc..30d4e950 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -19,12 +19,12 @@ import numpy as np import pymc as pm import pytensor.tensor as pt -from pymc.logprob.transforms import RVTransform +from pymc.logprob.transforms import Transform class ParamCfg(TypedDict): name: str - transform: Optional[RVTransform] + transform: Optional[Transform] dims: Optional[Union[str, Tuple[str]]] @@ -44,14 +44,14 @@ class FlatInfo(TypedDict): info: List[VarInfo] -def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, RVTransform, str, Tuple]] = None): +def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, Transform, str, Tuple]] = None): if value is None: cfg = ParamCfg(name=key, transform=None, dims=None) elif isinstance(value, Tuple): cfg = ParamCfg(name=key, transform=None, dims=value) elif isinstance(value, str): cfg = ParamCfg(name=value, transform=None, dims=None) - elif isinstance(value, RVTransform): + elif isinstance(value, Transform): cfg = ParamCfg(name=key, transform=value, dims=None) else: cfg = value.copy() @@ -62,7 +62,7 @@ def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, RVTransform, str, Tup def _parse_args( - var_names: Sequence[str], **kwargs: Union[ParamCfg, RVTransform, str, Tuple] + var_names: Sequence[str], **kwargs: Union[ParamCfg, Transform, str, Tuple] ) -> Dict[str, ParamCfg]: results = dict() for var in var_names: @@ -133,7 +133,7 @@ def prior_from_idata( name="trace_prior_", *, var_names: Sequence[str] = (), - **kwargs: Union[ParamCfg, RVTransform, str, Tuple] + **kwargs: Union[ParamCfg, Transform, str, Tuple] ) -> Dict[str, pt.TensorVariable]: """ Create a prior from posterior using MvNormal approximation. @@ -153,7 +153,7 @@ def prior_from_idata( Inference data with posterior group var_names: Sequence[str] names of variables to take as is from the posterior - kwargs: Union[ParamCfg, RVTransform, str, Tuple] + kwargs: Union[ParamCfg, Transform, str, Tuple] names of variables with additional configuration, see more in Examples Examples diff --git a/requirements.txt b/requirements.txt index 2609234c..f0965caa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pymc>=5.8.2 +pymc>=5.10.0 scikit-learn