From 053a58befa774e5eedcfdca58d8889a96cf79d33 Mon Sep 17 00:00:00 2001
From: Aryan Roy <50577809+aryan26roy@users.noreply.github.com>
Date: Mon, 26 Aug 2024 18:44:38 +0530
Subject: [PATCH] [ENH] Add the ability to find Proper Possibly Directed Paths (#112)

* Added function to return a list of possibly directed paths between two nodes

---------

Signed-off-by: Aryan Roy
Co-authored-by: Adam Li
---
 doc/whats_new/v0.2.rst                        |   1 +
 pywhy_graphs/algorithms/generic.py            | 165 ++++++++++++++++
 pywhy_graphs/algorithms/tests/test_generic.py | 198 +++++++++++++++++-
 3 files changed, 363 insertions(+), 1 deletion(-)

diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst
index aebb2d19e..92fe1043d 100644
--- a/doc/whats_new/v0.2.rst
+++ b/doc/whats_new/v0.2.rst
@@ -33,6 +33,7 @@ Changelog
 - |Feature| Implement functions for converting between a DAG and PDAG and CPDAG for generating consistent extensions of a CPDAG for example. These functions are :func:`pywhy_graphs.algorithms.pdag_to_cpdag`, :func:`pywhy_graphs.algorithms.pdag_to_dag` and :func:`pywhy_graphs.algorithms.dag_to_cpdag`, by `Adam Li`_ (:pr:`102`)
 - |API| Remove poetry based setup, by `Adam Li`_ (:pr:`110`)
 - |Feature| Implement and test function to validate PAG, by `Aryan Roy`_ (:pr:`100`)
+- |Feature| Implement and test function to find all the proper possibly directed paths, by `Aryan Roy`_ (:pr:`112`)
 
 Code and Documentation Contributors
 -----------------------------------
diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py
index f0011c655..5166c057c 100644
--- a/pywhy_graphs/algorithms/generic.py
+++ b/pywhy_graphs/algorithms/generic.py
@@ -19,6 +19,7 @@
     "dag_to_mag",
     "is_maximal",
     "all_vstructures",
+    "proper_possibly_directed_path",
 ]
 
 
@@ -855,3 +856,167 @@ def all_vstructures(G: nx.DiGraph, as_edges: bool = False):
                 else:
                     vstructs.add((p1, node, p2))  # type: ignore
     return vstructs
+
+
+def _check_back_arrow(G: ADMG, X, Y: set):
+    """Retrieve all the neighbors of X that do not have
+    an arrow pointing back to it.
+
+    Parameters
+    ----------
+    G : DiGraph
+        A directed graph.
+    X : Node
+        The node whose neighbors are filtered.
+    Y : Set
+        A set of neighbors of X.
+
+    Returns
+    -------
+    out : set
+        A set of all the neighbors of X that do not have an arrow pointing
+        back to it.
+    """
+    out = set()
+
+    for elem in Y:
+        if not (
+            G.has_edge(X, elem, G.bidirected_edge_name) or G.has_edge(elem, X, G.directed_edge_name)
+        ):
+            # add the node itself; ``set.update`` would iterate over it and
+            # split a multi-character node name into single characters
+            out.add(elem)
+
+    return out
+
+
+def _get_neighbors_of_set(G, X: set):
+    """Retrieve the first edge of every path that leaves the set X.
+
+    Note that if X is not a set, graph.neighbors(X) is sufficient.
+
+    Parameters
+    ----------
+    G : DiGraph
+        A directed graph.
+    X : Set
+        The set of source nodes.
+
+    Returns
+    -------
+    out : set
+        A set of ``(x, neighbor)`` tuples, one for every neighbor of a node
+        ``x`` in X that is outside X and has no arrow pointing back to ``x``.
+    """
+    out = set()
+
+    for elem in X:
+        elem_possible_neighbors = _check_back_arrow(G, elem, set(G.neighbors(elem)))
+
+        # a proper path may only contain its first node from X
+        for nbh in elem_possible_neighbors - X:
+            out.add((elem, nbh))
+
+    return out
+
+
+def _recursively_find_pd_paths(G, X, paths, Y):
+    """Recursively finds all the possibly directed paths for a given
+    graph.
+
+    Parameters
+    ----------
+    G : DiGraph
+        A directed graph.
+    X : Set
+        Source.
+    paths : Set
+        Set of initial paths from X, each one a tuple of nodes.
+    Y : Set
+        Destination
+
+    Returns
+    -------
+    out : set
+        A set of all the possibly directed paths.
+    """
+    new_paths = set()
+
+    for elem in paths:
+        cur_elem = elem[-1]
+
+        if cur_elem in Y:
+            new_paths.add(elem)
+            continue
+
+        nbr_possible = _check_back_arrow(G, cur_elem, G.neighbors(cur_elem))
+
+        # a path that cannot be extended never reaches Y, so it is dropped
+        if not nbr_possible:
+            continue
+
+        possible_end = nbr_possible.intersection(Y)
+
+        for nbr in possible_end:
+            new_paths.add(elem + (nbr,))
+
+        # do not revisit nodes already on the path and never re-enter X
+        remaining_nodes = nbr_possible.difference(possible_end, elem, X)
+
+        temp_set = {elem + (nbr,) for nbr in remaining_nodes}
+
+        new_paths.update(_recursively_find_pd_paths(G, X, temp_set, Y))
+
+    return new_paths
+
+
+def proper_possibly_directed_path(G, X: Optional[Set], Y: Optional[Set]):
+    """Find all the proper possibly directed paths in a graph. A proper possibly directed
+    path from X to Y is a set of edges with just the first node in X and none of the edges
+    with an arrow pointing back to X.
+
+    Parameters
+    ----------
+    G : DiGraph
+        A directed graph.
+    X : Set
+        Source. A single node is treated as a set with one element.
+    Y : Set
+        Destination. A single node is treated as a set with one element.
+
+    Returns
+    -------
+    out : set
+        A set of all the proper possibly directed paths.
+
+    Examples
+    --------
+    The function generates a set of tuples containing all the valid
+    proper possibly directed paths from X to Y.
+
+    >>> import pywhy_graphs
+    >>> from pywhy_graphs import PAG
+    >>> pag = PAG()
+    >>> pag.add_edge("A", "G", pag.directed_edge_name)
+    >>> pag.add_edge("G", "C", pag.directed_edge_name)
+    >>> pag.add_edge("C", "H", pag.directed_edge_name)
+    >>> pag.add_edge("Z", "C", pag.circle_edge_name)
+    >>> pag.add_edge("C", "Z", pag.circle_edge_name)
+    >>> pag.add_edge("Y", "X", pag.directed_edge_name)
+    >>> pag.add_edge("X", "Z", pag.directed_edge_name)
+    >>> pag.add_edge("Z", "K", pag.directed_edge_name)
+    >>> Y = {"H", "K"}
+    >>> X = {"Y", "A"}
+    >>> sorted(pywhy_graphs.proper_possibly_directed_path(pag, X, Y))
+    [('A', 'G', 'C', 'H'), ('A', 'G', 'C', 'Z', 'K'), ('Y', 'X', 'Z', 'C', 'H'), ('Y', 'X', 'Z', 'K')]
+
+    """
+    # normalize single nodes so that both arguments are always sets
+    if not isinstance(X, (set, frozenset)):
+        X = {X}
+    if not isinstance(Y, (set, frozenset)):
+        Y = {Y}
+
+    x_neighbors = _get_neighbors_of_set(G, X)
+
+    return _recursively_find_pd_paths(G, X, x_neighbors, Y)
diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py
index 09218a334..e3c7b2876 100644
--- a/pywhy_graphs/algorithms/tests/test_generic.py
+++ b/pywhy_graphs/algorithms/tests/test_generic.py
@@ -2,7 +2,7 @@
 import pytest
 
 import pywhy_graphs
-from pywhy_graphs import ADMG
+from pywhy_graphs import ADMG, PAG
 from pywhy_graphs.algorithms import all_vstructures
 
 
@@ -496,3 +496,199 @@ def test_all_vstructures():
     # Assert that the returned values are as expected
     assert len(v_structs_edges) == 0
     assert len(v_structs_tuples) == 0
+
+
+def test_proper_possibly_directed():
+    # Y -> X -> Z -> H
+
+    admg = ADMG()
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "H", admg.directed_edge_name)
+
+    Y = {"H"}
+    X = {"Y"}
+
+    correct = {("Y", "X", "Z", "H")}
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("A", "X", admg.directed_edge_name)
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "H", admg.directed_edge_name)
+
+    Y = {"H"}
+    X = {"Y", "A"}
+
+    correct = {("Y", "X", "Z", "H"), ("A", "X", "Z", "H")}
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("X", "A", admg.directed_edge_name)
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "H", admg.directed_edge_name)
+
+    Y = {"H"}
+    X = {"Y", "A"}
+
+    correct = {("Y", "X", "Z", "H")}
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("X", "A", admg.directed_edge_name)
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "H", admg.directed_edge_name)
+    admg.add_edge("K", "Z", admg.directed_edge_name)
+
+    Y = {"H", "K"}
+    X = {"Y", "A"}
+
+    correct = {("Y", "X", "Z", "H")}
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("A", "X", admg.directed_edge_name)
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "H", admg.directed_edge_name)
+    admg.add_edge("Z", "K", admg.directed_edge_name)
+
+    Y = {"H", "K"}
+    X = {"Y", "A"}
+
+    correct = {
+        ("Y", "X", "Z", "K"),
+        ("A", "X", "Z", "K"),
+        ("Y", "X", "Z", "H"),
+        ("A", "X", "Z", "H"),
+    }
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("A", "G", admg.directed_edge_name)
+    admg.add_edge("G", "C", admg.directed_edge_name)
+    admg.add_edge("C", "H", admg.directed_edge_name)
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "K", admg.directed_edge_name)
+
+    Y = {"H", "K"}
+    X = {"Y", "A"}
+
+    correct = {("Y", "X", "Z", "K"), ("A", "G", "C", "H")}
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("A", "G", admg.directed_edge_name)
+    admg.add_edge("G", "C", admg.directed_edge_name)
+    admg.add_edge("C", "H", admg.directed_edge_name)
+    admg.add_edge("Z", "C", admg.directed_edge_name)
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "K", admg.directed_edge_name)
+
+    Y = {"H", "K"}
+    X = {"Y", "A"}
+
+    correct = {("Y", "X", "Z", "K"), ("Y", "X", "Z", "C", "H"), ("A", "G", "C", "H")}
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("A", "G", admg.directed_edge_name)
+    admg.add_edge("A", "H", admg.directed_edge_name)
+    admg.add_edge("K", "G", admg.directed_edge_name)
+    admg.add_edge("K", "H", admg.directed_edge_name)
+
+    Y = {"G", "H"}
+    X = {"A", "K"}
+
+    correct = {("K", "H"), ("K", "G"), ("A", "G"), ("A", "H")}
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("A", "G", admg.directed_edge_name)
+    admg.add_edge("G", "C", admg.directed_edge_name)
+    admg.add_edge("C", "H", admg.directed_edge_name)
+    admg.add_edge("Z", "C", admg.bidirected_edge_name)
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "K", admg.directed_edge_name)
+
+    Y = {"H", "K"}
+    X = {"Y", "A"}
+
+    correct = {
+        ("A", "G", "C", "H"),
+        ("Y", "X", "Z", "K"),
+    }
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+    admg = ADMG()
+    admg.add_edge("A", "G", admg.directed_edge_name)
+    admg.add_edge("G", "C", admg.directed_edge_name)
+    admg.add_edge("C", "H", admg.directed_edge_name)
+    admg.add_edge("Z", "C", admg.bidirected_edge_name)
+    admg.add_edge("Y", "X", admg.directed_edge_name)
+    admg.add_edge("X", "Z", admg.directed_edge_name)
+    admg.add_edge("Z", "K", admg.directed_edge_name)
+
+    Y = {"H", "K"}
+    X = {"Y", "A"}
+
+    correct = {("Y", "X", "Z", "K"), ("A", "G", "C", "H")}
+    out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
+    assert correct == out
+
+
+def test_ppdp_PAG():
+
+    pag = PAG()
+    pag.add_edge("A", "G", pag.directed_edge_name)
+    pag.add_edge("G", "C", pag.directed_edge_name)
+    pag.add_edge("C", "H", pag.directed_edge_name)
+    pag.add_edge("Z", "C", pag.circle_edge_name)
+    pag.add_edge("C", "Z", pag.circle_edge_name)
+    pag.add_edge("Y", "X", pag.directed_edge_name)
+    pag.add_edge("X", "Z", pag.directed_edge_name)
+    pag.add_edge("Z", "K", pag.directed_edge_name)
+
+    Y = {"H", "K"}
+    X = {"Y", "A"}
+
+    correct = {
+        ("Y", "X", "Z", "K"),
+        ("Y", "X", "Z", "C", "H"),
+        ("A", "G", "C", "H"),
+        ("A", "G", "C", "Z", "K"),
+    }
+    out = pywhy_graphs.proper_possibly_directed_path(pag, X, Y)
+    assert correct == out
+
+
+def test_ppdp_multichar_and_single_nodes():
+    # src -> mid -> dst; mid -> sink, where sink is a dead end outside of Y
+    admg = ADMG()
+    admg.add_edge("src", "mid", admg.directed_edge_name)
+    admg.add_edge("mid", "dst", admg.directed_edge_name)
+    admg.add_edge("mid", "sink", admg.directed_edge_name)
+
+    correct = {("src", "mid", "dst")}
+
+    out = pywhy_graphs.proper_possibly_directed_path(admg, {"src"}, {"dst"})
+    assert correct == out
+
+    # single nodes are accepted in place of one-element sets
+    out = pywhy_graphs.proper_possibly_directed_path(admg, "src", "dst")
+    assert correct == out