Removing context initially where I can
Signed-off-by: Adam Li <[email protected]>
adam2392 committed Jul 18, 2024
1 parent 190c970 commit db31a80
Showing 9 changed files with 134 additions and 42 deletions.
47 changes: 24 additions & 23 deletions doc/api.rst
Original file line number Diff line number Diff line change
@@ -16,28 +16,6 @@ for classes (``CamelCase`` names) and functions
stage.


Context for causal discovery
============================

In many cases, structure learning has additional "context" beyond the data
itself, in the form of a priori knowledge of the structure or additional
datasets from different environments.
All structure learning algorithms in ``dodiscover`` accept a ``Context``
object rather than just data. Use the ``make_context`` builder API
to construct a ``Context``.

See docs for ``Context`` and ``make_context`` for more information.

.. currentmodule:: dodiscover
.. autosummary::
:toctree: generated/

make_context
ContextBuilder
InterventionalContextBuilder
context.Context


Constraint-based structure learning
===================================

@@ -172,4 +150,27 @@ independence.
:toctree: generated/

generate_knn_in_subspace
restricted_nbr_permutation
restricted_nbr_permutation

**The following API is for internal development and is completely experimental.**

Context for causal discovery
============================

In many cases, structure learning has additional "context" beyond the data
itself, in the form of a priori knowledge of the structure or additional
datasets from different environments.
All structure learning algorithms in ``dodiscover`` accept a ``Context``
object rather than just data. Use the ``make_context`` builder API
to construct a ``Context``.

See docs for ``Context`` and ``make_context`` for more information.

.. currentmodule:: dodiscover
.. autosummary::
:toctree: generated/

make_context
ContextBuilder
InterventionalContextBuilder
context.Context
21 changes: 16 additions & 5 deletions dodiscover/constraint/_classes.py
@@ -132,7 +132,7 @@ def _initialize_sep_sets(self, init_graph: nx.Graph) -> SeparatingSet:
sep_set: SeparatingSet = defaultdict(lambda: defaultdict(list))

# since we are not starting from a complete graph, find the separating sets
for (node_i, node_j) in itertools.combinations(init_graph.nodes, 2):
for node_i, node_j in itertools.combinations(init_graph.nodes, 2):
if not init_graph.has_edge(node_i, node_j):
sep_set[node_i][node_j] = []
sep_set[node_j][node_i] = []
@@ -184,7 +184,11 @@ def orient_edges(self, graph: EquivalenceClass) -> None:
"skeleton graph given a separating set."
)

def learn_graph(self, data: pd.DataFrame, context: Context):
def learn_graph(
self,
data: pd.DataFrame,
context: Context = None,
):
"""Fit constraint-based discovery algorithm on dataset 'X'.
Parameters
@@ -208,6 +212,12 @@ def learn_graph(self, data: pd.DataFrame, context: Context):
Control over the constraints imposed by the algorithm can be passed into the class
constructor.
"""
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()
self.context_ = context.copy()

# initialize graph object to apply learning
@@ -263,8 +273,9 @@ def evaluate_edge(
def learn_skeleton(
self,
data: pd.DataFrame,
context: Context,
context: Context = None,
sep_set: Optional[SeparatingSet] = None,
**params,
) -> Tuple[nx.Graph, SeparatingSet]:
"""Learns the skeleton of a causal DAG using pairwise (conditional) independence testing.
@@ -274,10 +285,10 @@ def learn_skeleton(
----------
data : pd.DataFrame
The dataset.
context : Context
A context object.
sep_set : dict of dict of list of set
The separating set.
params : dict
Additional parameters to pass to the method.
Returns
-------
17 changes: 14 additions & 3 deletions dodiscover/constraint/fcialg.py
@@ -706,7 +706,7 @@ def _apply_rule10(
# that:
# i) begin the uncovered pd path and
# ii) are distinct (done by construction) here
for (m, w) in combinations(graph.neighbors(a), 2): # type: ignore
for m, w in combinations(graph.neighbors(a), 2): # type: ignore
if m == c or w == c:
continue

@@ -763,7 +763,7 @@ def _apply_orientation_rules(self, graph: EquivalenceClass, sep_set: SeparatingS
logger.info(f"Running R1-10 for iteration {idx}")

for u in graph.nodes:
for (a, c) in permutations(graph.neighbors(u), 2):
for a, c in permutations(graph.neighbors(u), 2):
logger.debug(f"Check {u} {a} {c}")

# apply R1-3 to orient triples and arrowheads
@@ -821,8 +821,19 @@ def _apply_orientation_rules(self, graph: EquivalenceClass, sep_set: SeparatingS
idx += 1

def learn_skeleton(
self, data: pd.DataFrame, context: Context, sep_set: Optional[SeparatingSet] = None
self,
data: pd.DataFrame,
context: Context = None,
sep_set: Optional[SeparatingSet] = None,
**params,
) -> Tuple[nx.Graph, SeparatingSet]:
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

# now compute all possibly d-separating sets and learn a better skeleton
skel_alg = LearnSemiMarkovianSkeleton(
self.ci_estimator,
22 changes: 20 additions & 2 deletions dodiscover/constraint/intervention.py
@@ -132,8 +132,19 @@ def __init__(
self.known_intervention_targets = known_intervention_targets

def learn_skeleton(
self, data: pd.DataFrame, context: Context, sep_set: Optional[SeparatingSet] = None
self,
data: pd.DataFrame,
context: Context = None,
sep_set: Optional[SeparatingSet] = None,
**params,
) -> Tuple[nx.Graph, SeparatingSet]:
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

# now compute all possibly d-separating sets and learn a better skeleton
self.skeleton_learner_ = LearnInterventionSkeleton(
self.ci_estimator,
@@ -157,7 +168,7 @@ def learn_skeleton(
self.n_ci_tests += self.skeleton_learner_.n_ci_tests
return skel_graph, sep_set

def learn_graph(self, data: List[pd.DataFrame], context: Context):
def learn_graph(self, data: List[pd.DataFrame], context: Context = None):
"""Learn the relevant causal graph equivalence class.
From the pairs of datasets, we take all combinations and
@@ -177,6 +188,13 @@ def learn_graph(self, data: List[pd.DataFrame], context: Context):
self : PsiFCI
The fitted learner.
"""
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

if not isinstance(data, list):
raise TypeError("The input datasets must be in a Python list.")

13 changes: 12 additions & 1 deletion dodiscover/constraint/pcalg.py
@@ -136,7 +136,11 @@ def convert_skeleton_graph(self, graph: nx.Graph) -> EquivalenceClass:
return graph

def learn_skeleton(
self, data: pd.DataFrame, context: Context, sep_set: Optional[SeparatingSet] = None
self,
data: pd.DataFrame,
context: Context = None,
sep_set: Optional[SeparatingSet] = None,
**params,
) -> Tuple[nx.Graph, SeparatingSet]:
"""Learns the skeleton of a causal DAG using pairwise (conditional) independence testing.
@@ -162,6 +166,13 @@ def learn_skeleton(
to determine which variables are (in)dependent. This specific algorithm
compares exhaustively pairs of adjacent variables.
"""
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

skel_alg = LearnSkeleton(
self.ci_estimator,
sep_set=sep_set,
27 changes: 24 additions & 3 deletions dodiscover/constraint/skeleton.py
@@ -845,7 +845,14 @@ def _initialize_params(self, context) -> Context:
nx.set_edge_attributes(context.init_graph, -1e-5, "pvalue")
return context

def learn_graph(self, data: pd.DataFrame, context: Context, check_input: bool = True):
def learn_graph(self, data: pd.DataFrame, context: Context = None, check_input: bool = True):
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

if check_input:
# initialize learning parameters
context = self._initialize_params(context)
@@ -1076,7 +1083,14 @@ def _initialize_params(self, context) -> Context:

return super()._initialize_params(context)

def learn_graph(self, data: pd.DataFrame, context: Context, check_input: bool = True):
def learn_graph(self, data: pd.DataFrame, context: Context = None, check_input: bool = True):
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

if check_input:
context = self._initialize_params(context)

@@ -1214,8 +1228,15 @@ def __init__(
self.known_intervention_targets = known_intervention_targets

def learn_graph(
self, data: List[pd.DataFrame], context: Context, check_input: bool = True
self, data: List[pd.DataFrame], context: Context = None, check_input: bool = True
) -> None:
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

# ensure data is a list
if isinstance(data, pd.DataFrame):
data = [data]
9 changes: 8 additions & 1 deletion dodiscover/replearning/gin.py
@@ -72,7 +72,7 @@ def __init__(self, ci_estimator_method: str = "kci", alpha: float = 0.05):
# go in a base class too.
self.causal_learn_graph_ = None

def learn_graph(self, data: DataFrame, context: DataFrame):
def learn_graph(self, data: DataFrame, context: DataFrame = None):
"""Fit the GIN model to data.
Currently the context object is not used.
@@ -88,6 +88,13 @@ def learn_graph(self, data: DataFrame, context: DataFrame):
self : GIN
The fitted GIN object.
"""
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

from causallearn.search.HiddenCausal.GIN.GIN import GIN as GIN_

causal_learn_graph, _ = GIN_(data.to_numpy(), self.ci_estimator_method, self.alpha)
11 changes: 9 additions & 2 deletions dodiscover/toporder/_base.py
@@ -470,7 +470,7 @@ class TopOrderInterface(metaclass=ABCMeta):
"""

@abstractmethod
def learn_graph(self, data: pd.DataFrame, context: Context) -> None:
def learn_graph(self, data: pd.DataFrame, context: Context = None) -> None:
raise NotImplementedError()

@abstractmethod
@@ -585,7 +585,7 @@ def _get_leaf(self, leaf: int, remaining_nodes: List[int], current_order: List[i
k += 1
return leaf

def learn_graph(self, data_df: pd.DataFrame, context: Context) -> None:
def learn_graph(self, data_df: pd.DataFrame, context: Context = None) -> None:
"""
Fit topological order based causal discovery algorithm on input data.
@@ -596,6 +596,13 @@ def learn_graph(self, data_df: pd.DataFrame, context: Context) -> None:
context: Context
The context of the causal discovery problem.
"""
if context is None:
# make a private Context object to store causal context used in this algorithm
# store the context
from dodiscover.context_builder import make_context

context = make_context().build()

X = data_df.to_numpy()
self.context = context

9 changes: 7 additions & 2 deletions examples/plot_pc_alg.py
@@ -16,8 +16,8 @@
In this example, we will introduce the main abstractions and concepts used in
dodiscover for causal discovery:
- learner: Any causal discovery algorithm that has a similar scikit-learn API.
- context: Causal assumptions.
- learner: Any causal discovery algorithm with an API similar to scikit-learn's,
but which uses the ``learn_graph`` API to apply the algorithm to data.
.. currentmodule:: dodiscover
"""
@@ -178,6 +178,11 @@ def clone(self):
pc = PC(ci_estimator=ci_estimator)
pc.learn_graph(data, context)

# .. note:: You can also omit the ``context`` argument, and the
# algorithm will infer the context from the data. This is not recommended,
# as it is always better to specify the context a priori, before running a
# causal discovery algorithm.

# The resulting completed partially directed acyclic graph (CPDAG) that is learned
# is a "Markov equivalence class", which encodes all the conditional dependences that
# were learned from the data. Note here, because the CI test fails to find the

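The same guard (`if context is None: context = make_context().build()`) is repeated in each `learn_graph` and `learn_skeleton` method touched by this commit. The pattern can be sketched in isolation as below. This is a minimal illustration under stated assumptions, not the library's code: `Context`, `_default_context`, and `Learner` are hypothetical stand-ins so the block runs without `dodiscover` installed, and the `Optional[Context]` annotation is a suggestion that makes the `None` default explicit in the type hint (the commit itself writes `context: Context = None`).

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class Context:
    """Hypothetical stand-in for dodiscover's Context; not the real class."""

    included_edges: List[Tuple[str, str]] = field(default_factory=list)

    def copy(self) -> "Context":
        # Copy the list so the two objects do not share state.
        return Context(included_edges=list(self.included_edges))


def _default_context() -> Context:
    # Assumption: plays the role of ``make_context().build()`` in the commit.
    return Context()


class Learner:
    """Hypothetical learner mirroring the ``learn_graph`` signature."""

    def learn_graph(self, data, context: Optional[Context] = None) -> "Learner":
        if context is None:
            # No context supplied: fall back to a private default.
            context = _default_context()
        # Store a copy so the caller's object is never mutated.
        self.context_ = context.copy()
        return self


learner = Learner().learn_graph(data=[[0.1, 0.2]])
assert learner.context_.included_edges == []
```

Factoring the fallback into one helper, as `_default_context` does here, would be one way to avoid repeating the local import and build call in every method; whether that suits the project is a design choice for its maintainers.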