Skip to content

Commit

Permalink
Fix CI errors in statspace test suite (pymc-devs#251)
Browse files Browse the repository at this point in the history
* Ignore numpy depreciation warnings from statsmodel

* Squeeze result vector-matrix multiplication with (1, 1) matrix to avoid shape error in numpy 1.25.2

* Consolidate all project settings into `pyproject.toml`

* Delete unused`pytest.ini` and `setup.cfg`

* Remove unnecessary filtered warnings

* Change pathfinder jax import from deprecated pymc.sampling_jax to pymc.sampling.jax

* Skip pathfinder test if python < 3.10

* Add some comments to `pyproject.toml` to explain what warnings are being ignored and why
  • Loading branch information
jessegrabowski authored Oct 8, 2023
1 parent 3444ede commit e163a6f
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pymc_experimental/inference/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import numpy as np
import pymc as pm
from pymc import modelcontext
from pymc.sampling_jax import get_jaxified_graph
from pymc.sampling.jax import get_jaxified_graph
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames


Expand Down
1 change: 1 addition & 0 deletions pymc_experimental/tests/statespace/test_VARMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def test_VARMAX_param_counts_match_statsmodels(data, order, var):

@pytest.mark.parametrize("order", orders, ids=ids)
@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.EstimationWarning")
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_VARMAX_update_matches_statsmodels(data, order, rng):
p, q = order

Expand Down
4 changes: 2 additions & 2 deletions pymc_experimental/tests/statespace/utilities/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def simulate_from_numpy_model(mod, rng, param_dict, steps=100):
y = np.zeros(steps)

x[0] = x0
y[0] = Z @ x0
y[0] = (Z @ x0).squeeze()

if not np.allclose(H, 0):
y[0] += rng.multivariate_normal(mean=np.zeros(1), cov=H)
Expand All @@ -245,7 +245,7 @@ def simulate_from_numpy_model(mod, rng, param_dict, steps=100):
error = 0

x[t] = c + T @ x[t - 1] + innov
y[t] = d + Z @ x[t] + error
y[t] = (d + Z @ x[t] + error).squeeze()

return x, y

Expand Down
3 changes: 3 additions & 0 deletions pymc_experimental/tests/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

# TODO: Remove this filterwarning after pytensor uses jnp.prod instead of jnp.product
@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="pymc.sampling.jax does not currently support python < 3.10"
)
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_pathfinder():
# Data of the Eight Schools Model
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ addopts = [
"--ignore=pymc_experimental/model_builder.py"
]

filterwarnings =[
"error",
# Raised by arviz when the model_builder class adds non-standard group names to InferenceData
"ignore::UserWarning:arviz.data.inference_data",

# bool8, find_common_type, cumproduct, and product had deprecation warnings added in numpy 1.25
'ignore:.*(\b(pkg_resources\.declare_namespace|np\.bool8|np\.find_common_type|cumproduct|product)\b).*:DeprecationWarning',
]

[tool.black]
line-length = 100
Expand All @@ -20,6 +28,7 @@ exclude_lines = [

[tool.isort]
profile = "black"
# lines_between_types = 1

[tool.nbqa.mutate]
isort = 1
Expand Down
7 changes: 0 additions & 7 deletions pytest.ini

This file was deleted.

11 changes: 0 additions & 11 deletions setup.cfg

This file was deleted.

0 comments on commit e163a6f

Please sign in to comment.