Skip to content

Commit

Permalink
Update louvain_communities to match NetworkX 3.3 (added max_level) (#…
Browse files Browse the repository at this point in the history
…4177)

We already supported `max_level=`, and this was just upstreamed to networkx here: networkx/networkx#6909

Authors:
  - Erik Welch (https://github.com/eriknw)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4177
  • Loading branch information
eriknw authored Feb 26, 2024
1 parent 2c478fb commit 7a47ad0
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 7 deletions.
1 change: 0 additions & 1 deletion python/nx-cugraph/_nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@
},
"louvain_communities": {
"dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
"max_level : int, optional": "Upper limit of the number of macro-iterations (max: 500).",
},
"pagerank": {
"dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
Expand Down
19 changes: 14 additions & 5 deletions python/nx-cugraph/nx_cugraph/algorithms/community/louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.
import warnings

import networkx as nx
import pylibcugraph as plc

from nx_cugraph.convert import _to_undirected_graph
Expand All @@ -25,13 +26,21 @@

__all__ = ["louvain_communities"]

# max_level argument was added to NetworkX 3.3
if nx.__version__[:3] <= "3.2":
_max_level_param = {
"max_level : int, optional": (
"Upper limit of the number of macro-iterations (max: 500)."
)
}
else:
_max_level_param = {}


@not_implemented_for("directed")
@networkx_algorithm(
extra_params={
"max_level : int, optional": (
"Upper limit of the number of macro-iterations (max: 500)."
),
**_max_level_param,
**_dtype_param,
},
is_incomplete=True, # seed not supported; self-loops not supported
Expand All @@ -44,9 +53,9 @@ def louvain_communities(
weight="weight",
resolution=1,
threshold=0.0000001,
max_level=None,
seed=None,
*,
max_level=None,
dtype=None,
):
"""`seed` parameter is currently ignored, and self-loops are not yet supported."""
Expand Down Expand Up @@ -82,9 +91,9 @@ def _(
weight="weight",
resolution=1,
threshold=0.0000001,
max_level=None,
seed=None,
*,
max_level=None,
dtype=None,
):
# NetworkX allows both directed and undirected, but cugraph only allows undirected.
Expand Down
6 changes: 6 additions & 0 deletions python/nx-cugraph/nx_cugraph/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ def key(testpath):
): different_iteration_order,
}
)
elif nxver.minor >= 3:
xfail.update(
{
key("test_louvain.py:test_max_level"): louvain_different,
}
)

too_slow = "Too slow to run"
skip = {
Expand Down
10 changes: 9 additions & 1 deletion python/nx-cugraph/nx_cugraph/tests/test_match_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -14,10 +14,13 @@
import inspect

import networkx as nx
from packaging.version import parse

import nx_cugraph as nxcg
from nx_cugraph.utils import networkx_algorithm

nxver = parse(nx.__version__)


def test_match_signature_and_names():
"""Simple test to ensure our signatures and basic module layout match networkx."""
Expand All @@ -41,6 +44,11 @@ def test_match_signature_and_names():
else:
orig_func = dispatchable_func.orig_func

if nxver.major == 3 and nxver.minor <= 2 and name == "louvain_communities":
# The signature of louvain_communities changed in NetworkX 3.3, and
# we updated to match, so we skip this check in older versions.
continue

# Matching signatures?
orig_sig = inspect.signature(orig_func)
func_sig = inspect.signature(func)
Expand Down

0 comments on commit 7a47ad0

Please sign in to comment.