diff --git a/_nx_cugraph/__init__.py b/_nx_cugraph/__init__.py index 2c899f855..6b905e8db 100644 --- a/_nx_cugraph/__init__.py +++ b/_nx_cugraph/__init__.py @@ -342,16 +342,21 @@ def update_env_var(varname): return d -def _check_networkx_version() -> tuple[int, int] | tuple[int, int, int]: +def _check_networkx_version(nx_version=None) -> tuple[int, int] | tuple[int, int, int]: """Check the version of networkx and return ``(major, minor)`` version tuple.""" import re import warnings import networkx as nx - version_major, version_minor, *version_bug = nx.__version__.split(".")[:3] + if nx_version is None: + nx_version = nx.__version__ + version_major, version_minor, *version_bug = nx_version.split(".")[:3] if has_bug := bool(version_bug): version_bug = version_bug[0] + if "dev" in version_bug: + # For example: "3.5rc0.dev0" should give (3, 5) + has_bug = False if version_major != "3": warnings.warn( f"nx-cugraph version {__version__} is only known to work with networkx " diff --git a/nx_cugraph/tests/test_version.py b/nx_cugraph/tests/test_version.py index c45702b60..0a7985f2a 100644 --- a/nx_cugraph/tests/test_version.py +++ b/nx_cugraph/tests/test_version.py @@ -1,5 +1,8 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +import pytest + +import _nx_cugraph import nx_cugraph @@ -10,3 +13,27 @@ def test_version_constants_are_populated(): # __version__ should always be non-empty assert isinstance(nx_cugraph.__version__, str) assert len(nx_cugraph.__version__) > 0 + + +def test_nx_ver(): + assert _nx_cugraph._check_networkx_version() == nx_cugraph._nxver + assert _nx_cugraph._check_networkx_version("3.4") == (3, 4) + assert _nx_cugraph._check_networkx_version("3.4.2") == (3, 4, 2) + assert _nx_cugraph._check_networkx_version("3.4rc0") == (3, 4) + assert _nx_cugraph._check_networkx_version("3.4.2rc1") == (3, 4, 2) + assert _nx_cugraph._check_networkx_version("3.5rc0.dev0") == (3, 5) + assert _nx_cugraph._check_networkx_version("3.5.1rc0.dev0") == (3, 5, 1) + assert _nx_cugraph._check_networkx_version("3.5.dev0") == (3, 5) + assert _nx_cugraph._check_networkx_version("3.5.1.dev0") == (3, 5, 1) + with pytest.raises(ValueError, match="not enough values to unpack"): + _nx_cugraph._check_networkx_version("3") + with pytest.raises(RuntimeWarning, match="does not work with networkx version"): + _nx_cugraph._check_networkx_version("3.4bad") + with pytest.warns( + UserWarning, match="only known to work with networkx versions 3.x" + ): + assert _nx_cugraph._check_networkx_version("2.2") == (2, 2) + with pytest.warns( + UserWarning, match="only known to work with networkx versions 3.x" + ): + assert _nx_cugraph._check_networkx_version("4.2") == (4, 2)