diff --git a/autotest/conftest.py b/autotest/conftest.py index d952441..6d69b98 100644 --- a/autotest/conftest.py +++ b/autotest/conftest.py @@ -1,4 +1,4 @@ from pathlib import Path -pytest_plugins = ["modflow_devtools.fixtures"] +pytest_plugins = ["modflow_devtools.fixtures", "modflow_devtools.snapshots"] project_root_path = Path(__file__).parent diff --git a/autotest/test_snapshots.py b/autotest/test_snapshots.py index 0c81aba..01d5f3e 100644 --- a/autotest/test_snapshots.py +++ b/autotest/test_snapshots.py @@ -2,10 +2,11 @@ from pathlib import Path import numpy as np +import pytest +from _pytest.config import ExitCode proj_root = Path(__file__).parents[1] module_path = Path(inspect.getmodulename(__file__)) -pytest_plugins = ["modflow_devtools.snapshots"] # activate snapshot fixtures snapshot_array = np.array([1.1, 2.2, 3.3]) snapshots_path = proj_root / "autotest" / "__snapshots__" @@ -61,3 +62,47 @@ def test_readable_text_array_snapshot(readable_array_snapshot): ), snapshot_array, ) + + +@pytest.mark.meta("test_snapshot_disable") +def test_snapshot_disable_inner(snapshot): + assert snapshot == "match this!" + + +@pytest.mark.parametrize("disable", [True, False]) +def test_snapshot_disable(disable): + inner_fn = test_snapshot_disable_inner.__name__ + args = [ + __file__, + "-v", + "-s", + "-k", + inner_fn, + "-M", + "test_snapshot_disable", + ] + if disable: + args.append("--snapshot-disable") + assert pytest.main(args) == (ExitCode.OK if disable else ExitCode.TESTS_FAILED) + + +@pytest.mark.meta("test_array_snapshot_disable") +def test_array_snapshot_disable_inner(array_snapshot): + assert array_snapshot == "can you match that?" + + +@pytest.mark.parametrize("disable", [True, False]) +def test_array_snapshot_disable(disable): + inner_fn = test_array_snapshot_disable_inner.__name__ + args = [ + __file__, + "-v", + "-s", + "-k", + inner_fn, + "-M", + "test_array_snapshot_disable", + ] + if disable: + args.append("--snapshot-disable") + assert pytest.main(args) == (ExitCode.OK if disable else ExitCode.TESTS_FAILED) diff --git a/modflow_devtools/snapshots.py b/modflow_devtools/snapshots.py index eed8776..9049861 100644 --- a/modflow_devtools/snapshots.py +++ b/modflow_devtools/snapshots.py @@ -8,10 +8,13 @@ syrupy = import_optional_dependency("syrupy") # ruff: noqa: E402 +from syrupy import __import_extension +from syrupy.assertion import SnapshotAssertion from syrupy.extensions.single_file import ( SingleFileSnapshotExtension, WriteMode, ) +from syrupy.location import PyTestLocation from syrupy.types import ( PropertyFilter, PropertyMatcher, @@ -90,19 +93,67 @@ def serialize( return np.array2string(data, threshold=np.inf) +class MatchAnything: + def __eq__(self, _): + return True + + # fixtures +@pytest.fixture(scope="session") +def snapshot_disable(pytestconfig) -> bool: + return pytestconfig.getoption("--snapshot-disable") + + @pytest.fixture -def array_snapshot(snapshot): - return snapshot.use_extension(BinaryArrayExtension) +def snapshot(request, snapshot_disable) -> "SnapshotAssertion": + return ( + MatchAnything() + if snapshot_disable + else SnapshotAssertion( + update_snapshots=request.config.option.update_snapshots, + extension_class=__import_extension(request.config.option.default_extension), + test_location=PyTestLocation(request.node), + session=request.session.config._syrupy, + ) + ) @pytest.fixture -def text_array_snapshot(snapshot): - return snapshot.use_extension(TextArrayExtension) +def array_snapshot(snapshot, snapshot_disable): + return ( + MatchAnything() + if snapshot_disable + else snapshot.use_extension(BinaryArrayExtension) + ) @pytest.fixture -def readable_array_snapshot(snapshot): - return snapshot.use_extension(ReadableArrayExtension) +def text_array_snapshot(snapshot, snapshot_disable): + return ( + MatchAnything() + if snapshot_disable + else snapshot.use_extension(TextArrayExtension) + ) + + +@pytest.fixture +def readable_array_snapshot(snapshot, snapshot_disable): + return ( + MatchAnything() + if snapshot_disable + else snapshot.use_extension(ReadableArrayExtension) + ) + + +# pytest config hooks + + +def pytest_addoption(parser): + parser.addoption( + "--snapshot-disable", + action="store_true", + default=False, + help="Disable snapshot comparisons.", + )