From bfb268e6c55bcc893bf7f44f156cf0e9c1735497 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Thu, 21 Nov 2024 12:24:47 -0600 Subject: [PATCH] Update and test `_nx_cugraph._check_networkx_version` (#24) This fixes handling of the current dev version of networkx `"3.5rc0.dev0"`: https://github.com/networkx/networkx/blob/5c3e8beef128f532b536d2d4a9f7e309ed53416b/networkx/__init__.py#L11 Also, add testing since this is becoming exercised. Authors: - Erik Welch (https://github.com/eriknw) - Ralph Liu (https://github.com/nv-rliu) Approvers: - Ralph Liu (https://github.com/nv-rliu) - Rick Ratzel (https://github.com/rlratzel) - Kyle Edwards (https://github.com/KyleFromNVIDIA) URL: https://github.com/rapidsai/nx-cugraph/pull/24 --- _nx_cugraph/__init__.py | 9 +++++++-- nx_cugraph/tests/test_version.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/_nx_cugraph/__init__.py b/_nx_cugraph/__init__.py index 2c899f855..6b905e8db 100644 --- a/_nx_cugraph/__init__.py +++ b/_nx_cugraph/__init__.py @@ -342,16 +342,21 @@ def update_env_var(varname): return d -def _check_networkx_version() -> tuple[int, int] | tuple[int, int, int]: +def _check_networkx_version(nx_version=None) -> tuple[int, int] | tuple[int, int, int]: """Check the version of networkx and return ``(major, minor)`` version tuple.""" import re import warnings import networkx as nx - version_major, version_minor, *version_bug = nx.__version__.split(".")[:3] + if nx_version is None: + nx_version = nx.__version__ + version_major, version_minor, *version_bug = nx_version.split(".")[:3] if has_bug := bool(version_bug): version_bug = version_bug[0] + if "dev" in version_bug: + # For example: "3.5rc0.dev0" should give (3, 5) + has_bug = False if version_major != "3": warnings.warn( f"nx-cugraph version {__version__} is only known to work with networkx " diff --git a/nx_cugraph/tests/test_version.py b/nx_cugraph/tests/test_version.py index c45702b60..0a7985f2a 100644 --- a/nx_cugraph/tests/test_version.py +++ b/nx_cugraph/tests/test_version.py @@ -1,5 +1,8 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +import pytest + +import _nx_cugraph import nx_cugraph @@ -10,3 +13,27 @@ def test_version_constants_are_populated(): # __version__ should always be non-empty assert isinstance(nx_cugraph.__version__, str) assert len(nx_cugraph.__version__) > 0 + + +def test_nx_ver(): + assert _nx_cugraph._check_networkx_version() == nx_cugraph._nxver + assert _nx_cugraph._check_networkx_version("3.4") == (3, 4) + assert _nx_cugraph._check_networkx_version("3.4.2") == (3, 4, 2) + assert _nx_cugraph._check_networkx_version("3.4rc0") == (3, 4) + assert _nx_cugraph._check_networkx_version("3.4.2rc1") == (3, 4, 2) + assert _nx_cugraph._check_networkx_version("3.5rc0.dev0") == (3, 5) + assert _nx_cugraph._check_networkx_version("3.5.1rc0.dev0") == (3, 5, 1) + assert _nx_cugraph._check_networkx_version("3.5.dev0") == (3, 5) + assert _nx_cugraph._check_networkx_version("3.5.1.dev0") == (3, 5, 1) + with pytest.raises(ValueError, match="not enough values to unpack"): + _nx_cugraph._check_networkx_version("3") + with pytest.raises(RuntimeWarning, match="does not work with networkx version"): + _nx_cugraph._check_networkx_version("3.4bad") + with pytest.warns( + UserWarning, match="only known to work with networkx versions 3.x" + ): + assert _nx_cugraph._check_networkx_version("2.2") == (2, 2) + with pytest.warns( + UserWarning, match="only known to work with networkx versions 3.x" + ): + assert _nx_cugraph._check_networkx_version("4.2") == (4, 2)