Skip to content

Commit

Permalink
Simpler Stage representation.
Browse files Browse the repository at this point in the history
  • Loading branch information
felixleopoldo committed Nov 27, 2023
1 parent 3eedc20 commit c32f398
Show file tree
Hide file tree
Showing 12 changed files with 3,537 additions and 2,728 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@



This is a Python package for CStree models :footcite:p:`duarte2021representation`, a family of graphical causal models that encode context-specific information for discrete data.
This is a Python package for CStree models :footcite:p:`duarte2021representation`, a family of graphical causal models that encode context-specific dependence for multivariate multinomial distributions.

.. As not all staged tree models admit this property, CStrees are a subclass that provides a transparent, intuitive and compact representation of context-specific causal information.
Expand Down
1,222 changes: 1,087 additions & 135 deletions docs/source/fig1_demo.ipynb

Large diffs are not rendered by default.

182 changes: 92 additions & 90 deletions docs/source/learn_demo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/learn_demo_gibbs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@
"source": [
"pcgraph = pc(x[1:].values, 0.05, \"chisq\", node_names=x.columns)\n",
"poss_cvars = ctl.causallearn_graph_posscvars(pcgraph, labels=x.columns)\n",
"print(\"Possible context varaibles per node:\", poss_cvars)"
"print(\"Possible context vararibles per node:\", poss_cvars)"
]
},
{
Expand Down
1,222 changes: 1,087 additions & 135 deletions notebooks/fig1_demo.ipynb

Large diffs are not rendered by default.

182 changes: 92 additions & 90 deletions notebooks/learn_demo.ipynb

Large diffs are not rendered by default.

2,302 changes: 591 additions & 1,711 deletions notebooks/learn_demo_gibbs_cat.ipynb

Large diffs are not rendered by default.

350 changes: 176 additions & 174 deletions notebooks/learn_demo_mixed_cards.ipynb

Large diffs are not rendered by default.

398 changes: 199 additions & 199 deletions notebooks/learn_demo_v-structure.ipynb

Large diffs are not rendered by default.

282 changes: 137 additions & 145 deletions notebooks/learn_sandbox.ipynb

Large diffs are not rendered by default.

66 changes: 44 additions & 22 deletions src/cstrees/cstree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import logging
import sys
from importlib import reload # Not needed in Python 2
from importlib import reload

reload(logging)
FORMAT = '%(filename)s:%(funcName)s (%(lineno)d): %(message)s'
Expand All @@ -36,12 +36,10 @@ def plot(graph, layout="dot"):


class CStree:
""" A CStree class. The levels are enumerated from 0,...,p-1.
You may provide labels for the levels, corresponding to the
random variable they represent.
""" A class representing a CStree.
Args:
cards (list): A list of integers representing the cardinality of each level.
cards (list): A list of integers representing the cardinality of each level (indexed from 0).
labels (list, optional): A list of strings representing the labels of each level. Defaults to [0,1,...,p-1].
Example:
Expand Down Expand Up @@ -86,7 +84,7 @@ def stage_proportion(self, stage: st.Stage):
float: A number between 0 and 1.
Example:
# assuming all variables are binary
>>> # Assuming all variables are binary
>>> s = Stage([0, {0, 1}, 1])
>>> cstree.stage_proportion(s)
0.25
Expand Down Expand Up @@ -119,21 +117,45 @@ def update_stages(self, stages: dict):
>>> })
"""
for lev, stage_list in stages.items():
for stage in stage_list:
stage.cards = self.cards#[:lev+1] # Or full cards?


self.stages.update(stages)

stages_to_add = {key: [] for key in stages.keys()}

# If there are dicts, we convert them to Stages.
# Stages are updated to contain a cardinality list, inherited from the CStree.
for lev, list_of_stage_repr in stages.items():
# It can be either a Stage or a dict that should be converted to a Stage.

for stage_repr in list_of_stage_repr:
if isinstance(stage_repr, dict):
# If it's a dict, we convert it to a Stage.
stage_list_repr = []
for l in range(lev+1):
if l in stage_repr["context"]:
stage_list_repr.append(stage_repr["context"][l])
else:
stage_list_repr.append(set(range(self.cards[l])))
# Create a stage from the stage_list_repr
s = st.Stage(stage_list_repr)
if "color" in stage_repr:
s.color = stage_repr["color"]

s.cards = self.cards
stages_to_add[lev].append(s)
else:
# Just add the stage and set the cards
stage_repr.cards = self.cards
stages_to_add[lev].append(stage_repr)

self.stages.update(stages_to_add)
if -1 not in self.stages:
self.stages[-1] = [st.Stage([], color="black")]


def get_stage(self, node: tuple):
""" Get the stage of a node in the cstree.
""" Get the stage of a node in the CStree.
Args:
node (tuple): A node in the CStree.
node (tuple or list): A node in the CStree. It could be e.g. (0, 1, 0, 1).
Example:
>>> # tree is the fig. 1 CStree
>>> stage = tree.get_stage([0, 0])
Expand All @@ -158,10 +180,10 @@ def get_stage(self, node: tuple):
return stage

def to_df(self, write_probs=False):
""" Converts the CStree to a pandas dataframe.
""" Converts the CStree to a Pandas dataframe.
Returns:
df (pd.DataFrame): A pandas dataframe with the stages of the CStree.
df (pd.DataFrame): A Pandas dataframe with the stages of the CStree.
Example:
>>> tree.to_df()
Expand Down Expand Up @@ -438,11 +460,11 @@ def csi_relations(self, level="all"):
>>> rels = tree.csi_relations()
>>> for cont, rels in rels.items():
>>> for rel in rels:
>>> print(rel)
X0X2, X1=0
X1X3, X0=0, X2=0
X1X3, X0=0, X2=1
X1X3, X0=1, X2=0
>>> print(rel)
02 | 1=0
13 | 0=0, 2=0
13 | 0=0, 2=1
13 | 0=1, 2=0
"""
csi_rels = {}

Expand Down Expand Up @@ -555,7 +577,7 @@ def sample(self, n):
return df

def plot(self, full=False):
"""Plot the CStree. Make sure to set the parameters first.
"""Plot the CStree.
Args:
full (bool): If True, the tree is filled with parameters.
Expand Down
55 changes: 30 additions & 25 deletions src/cstrees/stage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
import random
import random

import numpy as np

Expand All @@ -23,16 +23,18 @@ class Stage:
"""

def __init__(self, list_repr, color=None, cards=None) -> None:
self.level = len(list_repr)-1
self.list_repr = list_repr
def __init__(self, stage_repr, color=None, cards=None) -> None:

self.level = len(stage_repr)-1
self.list_repr = stage_repr

# Check if singleton, if so set color black
if all([isinstance(i, int) for i in list_repr]):
if all([isinstance(i, int) for i in self.list_repr]):
self.color = "black"
else:
self.color = color
self.color = color
self.probs = None
self.cards=cards
self.cards = cards
self.csi = self.to_csi()

def __hash__(self) -> int:
Expand Down Expand Up @@ -72,7 +74,7 @@ def size(self):
if type(e) is set:
s *= len(e)
return s

def is_singleton(self):
"""
Checks if the stage is a singleton.
Expand Down Expand Up @@ -104,10 +106,11 @@ def to_df(self, column_labels, max_card=None, write_probs=False):
d[column_labels[i]] = [self.list_repr[i]]
else:
d[column_labels[i]] = ["-"]
if (self.probs is not None) and write_probs:

if (self.probs is not None) and write_probs:
df = pd.DataFrame(d, columns=column_labels[:-max_card])
df_prop = pd.DataFrame({"PROB_"+str(i):[prob] for i, prob in enumerate(self.probs)})
df_prop = pd.DataFrame(
{"PROB_"+str(i): [prob] for i, prob in enumerate(self.probs)})
df = pd.concat([df, df_prop], axis=1)
else:
df = pd.DataFrame(d, columns=column_labels)
Expand All @@ -117,7 +120,6 @@ def set_random_params(self, cards):
self.probs = np.random.dirichlet(
[1] * cards[self.level]) # Need to fix this


def __sub__(self, stage):
""" b is typically a sample from the space self.
Expand All @@ -127,9 +129,9 @@ def __sub__(self, stage):
Returns:
list: A list of CSI relations representing the new space.
"""
assert stage.cards is not None # Shouldnt use assert here
assert stage.cards is not None # Shouldnt use assert here
assert self.cards is not None

a = self
b = stage
p = self.level
Expand Down Expand Up @@ -176,7 +178,7 @@ def to_csi(self, labels=None):

ci = csi_relation.CI(sepseta, sepsetb, cond_set, labels=labels)
context = csi_relation.Context(context, labels=labels)

return csi_relation.CSI(ci, context, cards=self.cards)

def intersects(self, stage):
Expand Down Expand Up @@ -206,7 +208,7 @@ def to_cstree_paths(self):

def __str__(self) -> str:
if self.probs is not None:
return str(self.list_repr) + "; probs: " + str(self.probs)+ "; color: " + str(self.color)
return str(self.list_repr) + "; probs: " + str(self.probs) + "; color: " + str(self.color)
return str(self.list_repr)


Expand All @@ -226,7 +228,7 @@ def sample_stage_restr_by_stage(stage: Stage, max_cvars: int, cvar_prob: float,
"""

space = stage.list_repr
levelplus1 = len(space) # this is not the full p?
levelplus1 = len(space) # this is not the full p?

assert (max_cvars <= levelplus1) # < Since at least one cannot be a cvar.
# This may not be true if we are at very low levels where the level is the constraint.
Expand All @@ -236,17 +238,17 @@ def sample_stage_restr_by_stage(stage: Stage, max_cvars: int, cvar_prob: float,
cont_var_counter = 0
# random order here, to not favor low levels.
randorder = list(range(levelplus1))
random.shuffle(randorder)
random.shuffle(randorder)

# TODO: at level 0, it should be possible to have two singleton stages.
for i in range(levelplus1):
ind = randorder[i]
s = space[ind] # a context value (int) or the full set of values.
s = space[ind] # a context value (int) or the full set of values.

if type(s) is int: # This is a restriction of the space.
csilist[ind] = s
cont_var_counter += 1
else:
else:
if cont_var_counter < max_cvars-fixed_cvars: # Make sure not too many context vars
# (i.e. a cond var), pick either one or all.

Expand All @@ -255,14 +257,16 @@ def sample_stage_restr_by_stage(stage: Stage, max_cvars: int, cvar_prob: float,
if b == 0: # TODO: this should be able to happen anyway?
csilist[ind] = set(range(cards[ind]))
else:
v = np.random.randint(cards[ind]) # choose a random context value
# choose a random context value
v = np.random.randint(cards[ind])
cont_var_counter += 1
csilist[ind] = v
else: # no more context vars allowed.
else: # no more context vars allowed.
csilist[ind] = set(range(cards[ind]))

return Stage(csilist, cards=stage.cards)



def sample_random_stage(cards: list, level: int, max_contextvars: int, prob: float) -> Stage:
"""Sample a random non-singleton stage.
Expand All @@ -279,7 +283,8 @@ def sample_random_stage(cards: list, level: int, max_contextvars: int, prob: flo

# If the number is smaller than the level, then level is max.
ncont = max_contextvars
if max_contextvars > level-1: # Since not all can be context variables. (i.e. singleton stage)
# Since not all can be context variables. (i.e. singleton stage)
if max_contextvars > level-1:
ncont = level - 1

possible_context_vars = np.random.choice(
Expand All @@ -299,4 +304,4 @@ def sample_random_stage(cards: list, level: int, max_contextvars: int, prob: flo
else:
vals[i] = set(range(cards[i])) # use set here!
s = Stage(vals)
return s
return s

0 comments on commit c32f398

Please sign in to comment.