From 7a57e01d9936f69dfc57b47b4b86f32ad66f238c Mon Sep 17 00:00:00 2001 From: Jaron Lee Date: Wed, 5 Apr 2023 20:40:09 -0400 Subject: [PATCH 1/7] add diungraph Signed-off-by: Jaron Lee --- pywhy_graphs/classes/__init__.py | 2 +- .../classes/{cpdag.py => diungraph.py} | 230 ++++++++++++------ 2 files changed, 161 insertions(+), 71 deletions(-) rename pywhy_graphs/classes/{cpdag.py => diungraph.py} (67%) diff --git a/pywhy_graphs/classes/__init__.py b/pywhy_graphs/classes/__init__.py index 3a2fe76d4..ee36f646a 100644 --- a/pywhy_graphs/classes/__init__.py +++ b/pywhy_graphs/classes/__init__.py @@ -1,6 +1,6 @@ from . import timeseries from .admg import ADMG -from .cpdag import CPDAG +from .diungraph import CG, CPDAG from .intervention import IPAG, AugmentedGraph, PsiPAG from .pag import PAG from .timeseries import ( diff --git a/pywhy_graphs/classes/cpdag.py b/pywhy_graphs/classes/diungraph.py similarity index 67% rename from pywhy_graphs/classes/cpdag.py rename to pywhy_graphs/classes/diungraph.py index 960d96ace..60a746b24 100644 --- a/pywhy_graphs/classes/cpdag.py +++ b/pywhy_graphs/classes/diungraph.py @@ -8,67 +8,8 @@ from .base import AncestralMixin, ConservativeMixin -class CPDAG(pywhy_nx.MixedEdgeGraph, AncestralMixin, ConservativeMixin): - """Completed partially directed acyclic graphs (CPDAG). - - CPDAGs generalize causal DAGs by allowing undirected edges. - Undirected edges imply uncertainty in the orientation of the causal - relationship. For example, ``A - B``, can be ``A -> B`` or ``A <- B``, - allowing for a Markov equivalence class of DAGs for each CPDAG. - - Parameters - ---------- - incoming_directed_edges : input directed edges (optional, default: None) - Data to initialize directed edges. All arguments that are accepted - by `networkx.DiGraph` are accepted. - incoming_undirected_edges : input undirected edges (optional, default: None) - Data to initialize undirected edges. All arguments that are accepted - by `networkx.Graph` are accepted. 
- directed_edge_name : str - The name for the directed edges. By default 'directed'. - undirected_edge_name : str - The name for the directed edges. By default 'undirected'. - attr : keyword arguments, optional (default= no attributes) - Attributes to add to graph as key=value pairs. - - See Also - -------- - networkx.DiGraph - networkx.Graph - pywhy_graphs.ADMG - pywhy_graphs.networkx.MixedEdgeGraph - - Notes - ----- - CPDAGs are Markov equivalence class of causal DAGs. The implicit assumption in - these causal graphs are the Structural Causal Model (or SCM) is Markovian, inducing - causal sufficiency, where there is no unobserved latent confounder. This allows CPDAGs - to be learned from score-based (such as the "GES" algorithm) and constraint-based - (such as the PC algorithm) approaches for causal structure learning. - - One should not use CPDAGs if they suspect their data has unobserved latent confounders. - - **Edge Type Subgraphs** - - The data structure underneath the hood is stored in two networkx graphs: - ``networkx.Graph`` and ``networkx.DiGraph`` to represent the non-directed - edges and directed edges. Non-directed edges in an CPDAG can be present as - undirected edges standing for uncertainty in which directino the directed - edge is in. - - - Directed edges (<-, ->, indicating causal relationship) = `networkx.DiGraph` - The subgraph of directed edges may be accessed by the - `CPDAG.sub_directed_graph`. Their edges in networkx format can be - accessed by `CPDAG.directed_edges` and the corresponding name of the - edge type by `CPDAG.directed_edge_name`. - - Undirected edges (--, indicating uncertainty) = `networkx.Graph` - The subgraph of undirected edges may be accessed by the - `CPDAG.sub_undirected_graph`. Their edges in networkx format can be - accessed by `CPDAG.undirected_edges` and the corresponding name of the - edge type by `CPDAG.undirected_edge_name`. - - By definition, no cycles may exist due to the directed edges. 
- """ +class DiUnGraph(pywhy_nx.MixedEdgeGraph, AncestralMixin): + """ """ def __init__( self, @@ -85,15 +26,6 @@ def __init__( self._directed_name = directed_edge_name self._undirected_name = undirected_edge_name - from pywhy_graphs import is_valid_mec_graph - - # check that construction of PAG was valid - is_valid_mec_graph(self) - - # extended patterns store unfaithful triples - # these can be used for conservative structure learning algorithm - self._unfaithful_triples: Dict[FrozenSet[Node], None] = dict() - @property def undirected_edge_name(self) -> str: """Name of the undirected edge internal graph.""" @@ -184,6 +116,90 @@ def possible_parents(self, n: Node) -> Iterator[Node]: """ return self.sub_undirected_graph().neighbors(n) + +class CPDAG(DiUnGraph, ConservativeMixin): + """Completed partially directed acyclic graphs (CPDAG). + + CPDAGs generalize causal DAGs by allowing undirected edges. + Undirected edges imply uncertainty in the orientation of the causal + relationship. For example, ``A - B``, can be ``A -> B`` or ``A <- B``, + allowing for a Markov equivalence class of DAGs for each CPDAG. + + Parameters + ---------- + incoming_directed_edges : input directed edges (optional, default: None) + Data to initialize directed edges. All arguments that are accepted + by `networkx.DiGraph` are accepted. + incoming_undirected_edges : input undirected edges (optional, default: None) + Data to initialize undirected edges. All arguments that are accepted + by `networkx.Graph` are accepted. + directed_edge_name : str + The name for the directed edges. By default 'directed'. + undirected_edge_name : str + The name for the directed edges. By default 'undirected'. + attr : keyword arguments, optional (default= no attributes) + Attributes to add to graph as key=value pairs. + + See Also + -------- + networkx.DiGraph + networkx.Graph + pywhy_graphs.ADMG + pywhy_graphs.networkx.MixedEdgeGraph + + Notes + ----- + CPDAGs are Markov equivalence class of causal DAGs. 
The implicit assumption in + these causal graphs is that the Structural Causal Model (or SCM) is Markovian, inducing + causal sufficiency, where there is no unobserved latent confounder. This allows CPDAGs + to be learned from score-based (such as the "GES" algorithm) and constraint-based + (such as the PC algorithm) approaches for causal structure learning. + + One should not use CPDAGs if they suspect their data has unobserved latent confounders. + + **Edge Type Subgraphs** + + The data structure underneath the hood is stored in two networkx graphs: + ``networkx.Graph`` and ``networkx.DiGraph`` to represent the non-directed + edges and directed edges. + + - Directed edges (<-, ->, indicating causal relationship) = `networkx.DiGraph` + The subgraph of directed edges may be accessed by the + `CPDAG.sub_directed_graph`. Their edges in networkx format can be + accessed by `CPDAG.directed_edges` and the corresponding name of the + edge type by `CPDAG.directed_edge_name`. + - Undirected edges (--, indicating uncertainty) = `networkx.Graph` + The subgraph of undirected edges may be accessed by the + `CPDAG.sub_undirected_graph`. Their edges in networkx format can be + accessed by `CPDAG.undirected_edges` and the corresponding name of the + edge type by `CPDAG.undirected_edge_name`. + + By definition, no cycles may exist due to the directed edges. 
+ """ + + def __init__( + self, + incoming_directed_edges=None, + incoming_undirected_edges=None, + directed_edge_name: str = "directed", + undirected_edge_name: str = "undirected", + **attr, + ): + super().__init__( + incoming_directed_edges=incoming_directed_edges, + incoming_undirected_edges=incoming_undirected_edges, + directed_edge_name=directed_edge_name, + undirected_edge_name=undirected_edge_name, + ) + from pywhy_graphs import is_valid_mec_graph + + # check that construction of CPDAG was valid + is_valid_mec_graph(self) + + # extended patterns store unfaithful triples + # these can be used for conservative structure learning algorithm + self._unfaithful_triples: Dict[FrozenSet[Node], None] = dict() + def add_edge(self, u_of_edge, v_of_edge, edge_type="all", **attr): from pywhy_graphs.algorithms.generic import _check_adding_cpdag_edge @@ -200,3 +216,77 @@ def add_edges_from(self, ebunch_to_add, edge_type, **attr): self, u_of_edge=u_of_edge, v_of_edge=v_of_edge, edge_type=edge_type ) return super().add_edges_from(ebunch_to_add, edge_type, **attr) + + +class CG(DiUnGraph): + """Chain Graphs (CG). + + Chain graphs represent a generalization of DAGs and undirected graphs. + Undirected edges ``A - B`` in a chain graph represent a symmetric association of + two variables due to processes such as dynamic feedback (where ``A`` + influences ``B`` and vice versa) or an artefact of selection bias (where the selection + of the sample induces association between ``A`` and ``B``). + + + The implementation supports representation of both Lauritzen-Wermuth-Frydenberg (LWF) + and Andersson-Madigan-Perlman (AMP) chain graphs. + + + Parameters + ---------- + incoming_directed_edges : input directed edges (optional, default: None) + Data to initialize directed edges. All arguments that are accepted + by `networkx.DiGraph` are accepted. + incoming_undirected_edges : input undirected edges (optional, default: None) + Data to initialize undirected edges. 
All arguments that are accepted + by `networkx.Graph` are accepted. + directed_edge_name : str + The name for the directed edges. By default 'directed'. + undirected_edge_name : str + The name for the undirected edges. By default 'undirected'. + attr : keyword arguments, optional (default= no attributes) + Attributes to add to graph as key=value pairs. + + See Also + -------- + networkx.DiGraph + networkx.Graph + pywhy_graphs.ADMG + pywhy_graphs.networkx.MixedEdgeGraph + + Notes + ----- + **Edge Type Subgraphs** + + The data structure underneath the hood is stored in two networkx graphs: + ``networkx.Graph`` and ``networkx.DiGraph`` to represent the non-directed + edges and directed edges. + + - Directed edges (<-, ->, indicating causal relationship) = `networkx.DiGraph` + The subgraph of directed edges may be accessed by the + `CG.sub_directed_graph`. Their edges in networkx format can be + accessed by `CG.directed_edges` and the corresponding name of the + edge type by `CG.directed_edge_name`. + - Undirected edges (--, indicating symmetric association) = `networkx.Graph` + The subgraph of undirected edges may be accessed by the + `CG.sub_undirected_graph`. Their edges in networkx format can be + accessed by `CG.undirected_edges` and the corresponding name of the + edge type by `CG.undirected_edge_name`. + + By definition, no cycles may exist due to the directed edges. 
+ """ + + def __init__( + self, + incoming_directed_edges=None, + incoming_undirected_edges=None, + directed_edge_name: str = "directed", + undirected_edge_name: str = "undirected", + **attr, + ): + super().__init__( + incoming_directed_edges=incoming_directed_edges, + incoming_undirected_edges=incoming_undirected_edges, + directed_edge_name=directed_edge_name, + undirected_edge_name=undirected_edge_name, + ) From 328a656323e0df3a3fe3a0d9627e975030af4bd9 Mon Sep 17 00:00:00 2001 From: Jaron Lee Date: Wed, 5 Apr 2023 21:04:23 -0400 Subject: [PATCH 2/7] fix failing test Signed-off-by: Jaron Lee --- pywhy_graphs/classes/diungraph.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pywhy_graphs/classes/diungraph.py b/pywhy_graphs/classes/diungraph.py index 60a746b24..1ebc2fc86 100644 --- a/pywhy_graphs/classes/diungraph.py +++ b/pywhy_graphs/classes/diungraph.py @@ -225,7 +225,7 @@ class CG(DiUnGraph): Undirected edges ``A - B`` in a chain graph represent a symmetric association of two variables due to processes such as dynamic feedback (where ``A`` influences ``B`` and vice versa) or an artefact of selection bias (where the selection - of the sample induces association between ``A`` and ``B``). + of the sample induces association between ``A`` and ``B``) [1]_. The implementation supports representation of both Lauritzen-Wermuth-Frydenberg (LWF) @@ -247,6 +247,16 @@ class CG(DiUnGraph): attr : keyword arguments, optional (default= no attributes) Attributes to add to graph as key=value pairs. + References + ---------- + .. [1] Lauritzen, Steffen L., and Thomas S. Richardson. "Chain + graph models and their causal interpretations." Journal of the + Royal Statistical Society: Series B (Statistical Methodology) + 64.3 (2002): 321-348. 
+ + + + See Also -------- networkx.DiGraph From 1d3e990bf7455b5d728847af1c31ab4e5ae6a987 Mon Sep 17 00:00:00 2001 From: Jaron Lee Date: Wed, 5 Apr 2023 21:30:08 -0400 Subject: [PATCH 3/7] template for is_valid functions Signed-off-by: Jaron Lee --- pywhy_graphs/algorithms/generic.py | 45 ++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index b98caaf5d..a41117905 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -333,3 +333,48 @@ def _single_shortest_path_early_stop(G, firstlevel, paths, cutoff, join, valid_p nextlevel[w] = 1 level += 1 return paths + + +def is_valid_cg(G: DiUnGraph) -> bool: + """Check G is a valid Chain Graph (CG). + + A valid CG is one where no undirected edge can be oriented to form + a directed cycle. + + Parameters + ---------- + G : DiUnGraph + The graph with directed and undirected edges. + + Returns + ------- + bool + Whether G is a valid CG. + + Notes + ----- + """ + pass + + +def is_valid_cpdag(G: DiUnGraph) -> bool: + """Check G is a valid CPDAG. + + A valid CPDAG is one where each pair of nodes have + at most one edge between them and the internal graph of directed edges + do not form cycles. + + Parameters + ---------- + G : DiUnGraph + The graph with directed and undirected edges. + + Returns + ------- + bool + Whether G is a valid CPDAG. 
+ + Notes + ----- + """ + pass From 7800aee3fe95eb27aa442cd82a2487780aad2222 Mon Sep 17 00:00:00 2001 From: Jaron Lee Date: Sun, 16 Apr 2023 00:07:27 -0400 Subject: [PATCH 4/7] add framework for chain graph function and tests Signed-off-by: Jaron Lee --- pywhy_graphs/__init__.py | 1 + pywhy_graphs/algorithms/__init__.py | 1 + pywhy_graphs/algorithms/cg.py | 77 ++++++++++++++++ pywhy_graphs/algorithms/generic.py | 45 ---------- pywhy_graphs/algorithms/tests/test_cg.py | 108 +++++++++++++++++++++++ pywhy_graphs/classes/diungraph.py | 29 +++++- 6 files changed, 215 insertions(+), 46 deletions(-) create mode 100644 pywhy_graphs/algorithms/cg.py create mode 100644 pywhy_graphs/algorithms/tests/test_cg.py diff --git a/pywhy_graphs/__init__.py b/pywhy_graphs/__init__.py index 6b1b212e3..ec84cf71a 100644 --- a/pywhy_graphs/__init__.py +++ b/pywhy_graphs/__init__.py @@ -1,6 +1,7 @@ from ._version import __version__ # noqa: F401 from .classes import ( ADMG, + CG, CPDAG, PAG, AugmentedGraph, diff --git a/pywhy_graphs/algorithms/__init__.py b/pywhy_graphs/algorithms/__init__.py index 4698bf3da..22027575c 100644 --- a/pywhy_graphs/algorithms/__init__.py +++ b/pywhy_graphs/algorithms/__init__.py @@ -1,3 +1,4 @@ from .cyclic import * # noqa: F403 from .generic import * # noqa: F403 from .pag import * # noqa: F403 +from .cg import * # noqa: F403 diff --git a/pywhy_graphs/algorithms/cg.py b/pywhy_graphs/algorithms/cg.py new file mode 100644 index 000000000..a22f979b6 --- /dev/null +++ b/pywhy_graphs/algorithms/cg.py @@ -0,0 +1,77 @@ +from collections import deque + +import networkx as nx +import numpy as np + +from pywhy_graphs import CG + +__all__ = ["is_valid_cg"] + + +def is_valid_cg(graph: CG): + """ + Checks if a supplied chain graph is valid. + + This implements the original defintion of a (Lauritzen Wermuth Frydenberg) chain graph as + presented in [1]_. + + Define a cycle as a series of nodes X_1 -o X_2 ... X_n -o X_1 where the edges may be directed or + undirected. 
Note that directed edges in a cycle must all be aligned in the same direction. A + chain graph may only contain cycles consisting of only undirected edges. Equivalently, a chain + graph does not contain any cycles with one or more directed edges. + + Parameters + ---------- + graph : CG + The graph. + + Returns + ------- + is_valid : bool + Whether supplied `graph` is a valid chain graph. + + References + ---------- + .. [1] Frydenberg, Morten. “The Chain Graph Markov Property.” Scandinavian Journal of + Statistics, vol. 17, no. 4, 1990, pp. 333–53. JSTOR, http://www.jstor.org/stable/4616181. + Accessed 15 Apr. 2023. + + + """ + + # Check if directed edges are acyclic + undirected_edge_name = graph.undirected_edge_name + directed_edge_name = graph.directed_edge_name + visited = set() + all_nodes = graph.nodes() + G_undirected = graph.get_graphs(edge_type=undirected_edge_name) + G_directed = graph.get_graphs(edge_type=directed_edge_name) + # TODO: keep track of paths as first class in queue + for v in all_nodes: + print("v:", v) + seen = {v} + queue = deque([z for _, z in G_directed.out_edges(nbunch=v)]) + if v in visited: + + continue + while queue: + print(queue) + x = queue.popleft() + print("pop", x) + print("seen", seen) + if x in seen: + print("appeared in seen", x) + return False + + seen.add(x) + + for _, node in G_directed.out_edges(nbunch=x): + print("add out edge", node) + queue.append(node) + for nbr in G_undirected.neighbors(x): + print("add nbr edge", nbr) + queue.append(nbr) + + visited.add(v) + + return True diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index a41117905..b98caaf5d 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -333,48 +333,3 @@ def _single_shortest_path_early_stop(G, firstlevel, paths, cutoff, join, valid_p nextlevel[w] = 1 level += 1 return paths - - -def is_valid_cg(G: DiUnGraph) -> bool: - """Check G is a valid Chain Graph (CG). 
- - A valid CG is one where no undirected edge can be oriented to form - a directed cycle. - - Parameters - ---------- - G : DiUnGraph - The graph with directed and undirected edges. - - Returns - ------- - bool - Whether G is a valid CG. - - Notes - ----- - """ - pass - - -def is_valid_cpdag(G: DiUnGraph) -> bool: - """Check G is a valid CPDAG. - - A valid CPDAG is one where each pair of nodes have - at most one edge between them and the internal graph of directed edges - do not form cycles. - - Parameters - ---------- - G : DiUnGraph - The graph with directed and undirected edges. - - Returns - ------- - bool - Whether G is a valid CPDAG. - - Notes - ----- - """ - pass diff --git a/pywhy_graphs/algorithms/tests/test_cg.py b/pywhy_graphs/algorithms/tests/test_cg.py new file mode 100644 index 000000000..d78dac928 --- /dev/null +++ b/pywhy_graphs/algorithms/tests/test_cg.py @@ -0,0 +1,108 @@ +from pywhy_graphs import CG +from pywhy_graphs.algorithms import is_valid_cg +import pytest + + +@pytest.fixture +def cg_simple_partially_directed_cycle(): + graph = CG() + graph.add_nodes_from(["A", "B", "C", "D"]) + graph.add_edge("A", "B", graph.directed_edge_name) + graph.add_edge("D", "C", graph.directed_edge_name) + graph.add_edge("B", "D", graph.undirected_edge_name) + graph.add_edge("A", "C", graph.undirected_edge_name) + + return graph + + +@pytest.fixture +def cg_multiple_blocks_partially_directed_cycle(): + + graph = CG() + graph.add_nodes_from(["A", "B", "C", "D", "E", "F", "G"]) + graph.add_edge("A", "B", graph.directed_edge_name) + graph.add_edge("D", "C", graph.directed_edge_name) + graph.add_edge("B", "D", graph.undirected_edge_name) + graph.add_edge("A", "C", graph.undirected_edge_name) + graph.add_edge("E", "F", graph.undirected_edge_name) + graph.add_edge("F", "G", graph.undirected_edge_name) + graph.add_edge("G", "E", graph.undirected_edge_name) + + return graph + + +@pytest.fixture +def square_graph(): + graph = CG() + graph.add_nodes_from(["A", "B", "C", 
"D"]) + graph.add_edge("A", "B", graph.undirected_edge_name) + graph.add_edge("B", "C", graph.undirected_edge_name) + graph.add_edge("C", "D", graph.undirected_edge_name) + graph.add_edge("C", "A", graph.undirected_edge_name) + + return graph + + +@pytest.fixture +def fig_g1_frydenberg(): + graph = CG() + graph.add_nodes_from(["a", "b", "g", "m", "d"]) + graph.add_edge("a", "b", graph.undirected_edge_name) + graph.add_edge("b", "g", graph.directed_edge_name) + graph.add_edge("g", "d", graph.undirected_edge_name) + graph.add_edge("d", "m", graph.undirected_edge_name) + graph.add_edge("a", "m", graph.directed_edge_name) + + return graph + + +@pytest.fixture +def fig_g2_frydenberg(): + graph = CG() + graph.add_nodes_from(["b", "g", "d", "m", "a"]) + graph.add_edge("a", "m", graph.directed_edge_name) + graph.add_edge("m", "g", graph.undirected_edge_name) + graph.add_edge("m", "d", graph.directed_edge_name) + graph.add_edge("g", "d", graph.directed_edge_name) + graph.add_edge("b", "g", graph.directed_edge_name) + + return graph + + +@pytest.fixture +def fig_g3_frydenberg(): + graph = CG() + graph.add_nodes_from(["a", "b", "g"]) + graph.add_edge("b", "a", graph.undirected_edge_name) + graph.add_edge("a", "g", graph.undirected_edge_name) + graph.add_edge("b", "g", graph.directed_edge_name) + + return graph + + +@pytest.mark.parametrize( + "G", + [ + "cg_simple_partially_directed_cycle", + "cg_multiple_blocks_partially_directed_cycle", + "fig_g3_frydenberg", + ], +) +def test_graphs_are_not_valid_cg(G, request): + graph = request.getfixturevalue(G) + + assert not is_valid_cg(graph) + + +@pytest.mark.parametrize( + "G", + [ + "square_graph", + "fig_g1_frydenberg", + "fig_g2_frydenberg", + ], +) +def test_graphs_are_valid_cg(G, request): + graph = request.getfixturevalue(G) + + assert is_valid_cg(graph) diff --git a/pywhy_graphs/classes/diungraph.py b/pywhy_graphs/classes/diungraph.py index 1ebc2fc86..28a90c477 100644 --- a/pywhy_graphs/classes/diungraph.py +++ 
b/pywhy_graphs/classes/diungraph.py @@ -9,7 +9,34 @@ class DiUnGraph(pywhy_nx.MixedEdgeGraph, AncestralMixin): - """ """ + """ + Private class that represents an abstract MixedEdgeGraph with + only directed and undirected edges. + + This class is not intended for public use, and exists to reduce + duplication of code. + + Parameters + ---------- + incoming_directed_edges : input directed edges (optional, default: None) + Data to initialize directed edges. All arguments that are accepted + by `networkx.DiGraph` are accepted. + incoming_undirected_edges : input undirected edges (optional, default: None) + Data to initialize undirected edges. All arguments that are accepted + by `networkx.Graph` are accepted. + directed_edge_name : str + The name for the directed edges. By default 'directed'. + undirected_edge_name : str + The name for the directed edges. By default 'undirected'. + attr : keyword arguments, optional (default= no attributes) + Attributes to add to graph as key=value pairs. + + See also + -------- + + pywhy_graphs.CG + pywhy_graphs.CPDAG + """ def __init__( self, From d7f71b1f4daf2355db3d05efd2cf2f5404d46885 Mon Sep 17 00:00:00 2001 From: Jaron Lee Date: Sun, 16 Apr 2023 16:15:01 -0400 Subject: [PATCH 5/7] fix chain graph validity function Signed-off-by: Jaron Lee --- pywhy_graphs/algorithms/__init__.py | 2 +- pywhy_graphs/algorithms/cg.py | 67 ++++++++++++++---------- pywhy_graphs/algorithms/tests/test_cg.py | 18 ++++++- 3 files changed, 56 insertions(+), 31 deletions(-) diff --git a/pywhy_graphs/algorithms/__init__.py b/pywhy_graphs/algorithms/__init__.py index 22027575c..3043e89d7 100644 --- a/pywhy_graphs/algorithms/__init__.py +++ b/pywhy_graphs/algorithms/__init__.py @@ -1,4 +1,4 @@ +from .cg import * # noqa: F403 from .cyclic import * # noqa: F403 from .generic import * # noqa: F403 from .pag import * # noqa: F403 -from .cg import * # noqa: F403 diff --git a/pywhy_graphs/algorithms/cg.py b/pywhy_graphs/algorithms/cg.py index a22f979b6..490297b8b 
100644 --- a/pywhy_graphs/algorithms/cg.py +++ b/pywhy_graphs/algorithms/cg.py @@ -1,7 +1,5 @@ -from collections import deque - -import networkx as nx -import numpy as np +import copy +from collections import OrderedDict, deque from pywhy_graphs import CG @@ -42,36 +40,47 @@ def is_valid_cg(graph: CG): # Check if directed edges are acyclic undirected_edge_name = graph.undirected_edge_name directed_edge_name = graph.directed_edge_name - visited = set() all_nodes = graph.nodes() G_undirected = graph.get_graphs(edge_type=undirected_edge_name) G_directed = graph.get_graphs(edge_type=directed_edge_name) - # TODO: keep track of paths as first class in queue + + # Search over all nodes. for v in all_nodes: - print("v:", v) - seen = {v} - queue = deque([z for _, z in G_directed.out_edges(nbunch=v)]) - if v in visited: + queue = deque([]) + # Fill queue with paths from v starting with outgoing directed edge + # OrderedDict used for O(1) set membership and ordering + for _, z in G_directed.out_edges(nbunch=v): + d = OrderedDict() + d[v] = None + d[z] = None + queue.append(d) - continue while queue: - print(queue) - x = queue.popleft() - print("pop", x) - print("seen", seen) - if x in seen: - print("appeared in seen", x) - return False - - seen.add(x) - - for _, node in G_directed.out_edges(nbunch=x): - print("add out edge", node) - queue.append(node) - for nbr in G_undirected.neighbors(x): - print("add nbr edge", nbr) - queue.append(nbr) - - visited.add(v) + # For each path in queue, progress along edges in certain + # manner + path = queue.popleft() + rev_path = reversed(path) + last_added = next(rev_path) + second_last_added = next(rev_path) + + # For directed edges progress is allowed for outgoing edges + # only + for _, node in G_directed.out_edges(nbunch=last_added): + if node in path: + return False + new_path = copy.deepcopy(path) + new_path[node] = None + queue.append(new_path) + + # For undirected edges, progress is allowed for neighbors + # which were not visited. 
E.g. if the path is currently A - B, + # do not consider adding A when iterating over neighbors of B. + for node in G_undirected.neighbors(last_added): + if node != second_last_added: + if node in path: + return False + new_path = copy.deepcopy(path) + new_path[node] = None + queue.append(new_path) return True diff --git a/pywhy_graphs/algorithms/tests/test_cg.py b/pywhy_graphs/algorithms/tests/test_cg.py index d78dac928..479348dfe 100644 --- a/pywhy_graphs/algorithms/tests/test_cg.py +++ b/pywhy_graphs/algorithms/tests/test_cg.py @@ -1,6 +1,7 @@ +import pytest + from pywhy_graphs import CG from pywhy_graphs.algorithms import is_valid_cg -import pytest @pytest.fixture @@ -80,12 +81,27 @@ def fig_g3_frydenberg(): return graph +@pytest.fixture +def fig_g4_frydenberg(): + graph = CG() + graph.add_nodes_from(["b", "g", "d", "m", "a"]) + graph.add_edge("b", "g", graph.directed_edge_name) + graph.add_edge("a", "b", graph.undirected_edge_name) + graph.add_edge("g", "d", graph.undirected_edge_name) + graph.add_edge("d", "m", graph.undirected_edge_name) + graph.add_edge("m", "a", graph.undirected_edge_name) + graph.add_edge("a", "g", graph.directed_edge_name) + + return graph + + @pytest.mark.parametrize( "G", [ "cg_simple_partially_directed_cycle", "cg_multiple_blocks_partially_directed_cycle", "fig_g3_frydenberg", + "fig_g4_frydenberg", ], ) def test_graphs_are_not_valid_cg(G, request): From 79f74571d6ff45144169142f002ac4fb0093ed46 Mon Sep 17 00:00:00 2001 From: Jaron Lee Date: Sun, 16 Apr 2023 16:40:20 -0400 Subject: [PATCH 6/7] update changelog Signed-off-by: Jaron Lee --- docs/whats_new/v0.1.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/whats_new/v0.1.rst b/docs/whats_new/v0.1.rst index bb87bb11f..71c1f40d2 100644 --- a/docs/whats_new/v0.1.rst +++ b/docs/whats_new/v0.1.rst @@ -25,6 +25,7 @@ Version 0.1 Changelog --------- +- |Feature|| Introduce chain graphs and validity checks, and refactor CPDAG and chain graphs to use a directed-undirected private 
class, by `Jaron Lee`_ (:pr:`73`) - |Feature| Add keyword argument for graph labels in :func:`pywhy_graphs.viz.draw`, by `Aryan Roy`_ (:pr:`71`) - |Feature| Implement minimal m-separator function in :func:`pywhy_graphs.networkx.minimal_m_separator` with a BFS approach, by `Jaron Lee`_ (:pr:`53`) - |Feature| Implement m-separation :func:`pywhy_graphs.networkx.m_separated` with the BallTree approach, by `Jaron Lee`_ (:pr:`48`) From 7d69f7f3f10deec4d5e3d0ba159fdd259c93d587 Mon Sep 17 00:00:00 2001 From: Jaron Lee Date: Sun, 16 Apr 2023 17:52:47 -0400 Subject: [PATCH 7/7] fix spelling errors Signed-off-by: Jaron Lee --- docs/whats_new/v0.1.rst | 2 +- pywhy_graphs/algorithms/cg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/whats_new/v0.1.rst b/docs/whats_new/v0.1.rst index 71c1f40d2..f332b726a 100644 --- a/docs/whats_new/v0.1.rst +++ b/docs/whats_new/v0.1.rst @@ -25,7 +25,7 @@ Version 0.1 Changelog --------- -- |Feature|| Introduce chain graphs and validity checks, and refactor CPDAG and chain graphs to use a directed-undirected private class, by `Jaron Lee`_ (:pr:`73`) +- |Feature| Introduce chain graphs and validity checks, and refactor CPDAG and chain graphs to use a directed-undirected private class, by `Jaron Lee`_ (:pr:`73`) - |Feature| Add keyword argument for graph labels in :func:`pywhy_graphs.viz.draw`, by `Aryan Roy`_ (:pr:`71`) - |Feature| Implement minimal m-separator function in :func:`pywhy_graphs.networkx.minimal_m_separator` with a BFS approach, by `Jaron Lee`_ (:pr:`53`) - |Feature| Implement m-separation :func:`pywhy_graphs.networkx.m_separated` with the BallTree approach, by `Jaron Lee`_ (:pr:`48`) diff --git a/pywhy_graphs/algorithms/cg.py b/pywhy_graphs/algorithms/cg.py index 490297b8b..28a4d7241 100644 --- a/pywhy_graphs/algorithms/cg.py +++ b/pywhy_graphs/algorithms/cg.py @@ -10,7 +10,7 @@ def is_valid_cg(graph: CG): """ Checks if a supplied chain graph is valid. 
- This implements the original defintion of a (Lauritzen Wermuth Frydenberg) chain graph as + This implements the original definition of a (Lauritzen Wermuth Frydenberg) chain graph as presented in [1]_. Define a cycle as a series of nodes X_1 -o X_2 ... X_n -o X_1 where the edges may be directed or