Skip to content

Commit

Permalink
fix KeyError in small sample kl-div and prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Markham committed Jan 26, 2024
1 parent 6f1786c commit 37645ae
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 22 deletions.
52 changes: 51 additions & 1 deletion scripts/reproduce_uai
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import os
from importlib.metadata import version
import warnings
from networkx.generators.trees import _num_trees

# import pycurl
import pandas as pd
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.cit import chisq
import numpy as np

import cstrees.learning as ctl
import cstrees.scoring as sc
from cstrees.evaluate import kl_divergence
from cstrees import cstree as ct


def learn_cstree(data_path):
Expand Down Expand Up @@ -37,6 +40,51 @@ def learn_cstree(data_path):
return opt_tree


def learn_synth_cstree(data):
# estimate possible context variables and create score tables
pcgraph = pc(data.values, 0.05, "chisq", node_names=data.columns)
poss_cvars = ctl.causallearn_graph_to_posscvars(
pcgraph, labels=data.columns)
score_table, context_scores, _ = sc.order_score_tables(
data, max_cvars=1, alpha_tot=1.0, method="BDeu", poss_cvars=poss_cvars
)

# run Gibbs sampler to get MAP order
orders, scores = ctl.gibbs_order_sampler(5000, score_table)
map_order = orders[scores.index(max(scores))]

# estimate CStree
opt_tree = ctl._optimal_cstree_given_order(map_order, context_scores)

return opt_tree


def kl_exp(cards, samp_size):
true = ct.sample_cstree(
cards,
max_cvars=2,
prob_cvar=0.5,
prop_nonsingleton=1)
true.sample_stage_parameters(alpha=2)
data = true.sample(samp_size)

est = learn_synth_cstree(data)
est.estimate_stage_parameters(data)
est._create_tree()
return kl_divergence(est, true)


def run_kl_experiments():
num_runs = 10

cards = [2] * 5
samp_size = 10000
kl_values = np.empty(10, float)
for run_idx in range(num_runs):
kl_values[run_idx] = kl_exp(cards, samp_size)
return kl_values


# script logic for CLI
if __name__ == "__main__":
# check versions to ensure accurate reproduction
Expand Down Expand Up @@ -66,3 +114,5 @@ if __name__ == "__main__":
# c.setopt(c.WRITEDATA, f)
# c.perform()
# c.close()

np.random.seed(1312)
35 changes: 23 additions & 12 deletions src/cstrees/cstree.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from math import comb
import operator
from itertools import product, pairwise
from functools import reduce
from random import uniform
import logging
import sys
from importlib import reload # Not needed in Python 2
Expand Down Expand Up @@ -40,8 +38,13 @@ def write_minimal_context_graphs_to_files(context_dags, prefix="mygraphs"):
for key, val in context_dags.items():
agraph = nx.nx_agraph.to_agraph(val)
agraph.layout("dot")
agraph.draw(prefix + str(key) + ".png",
args='-Glabel="' + str(key) + '" ')
agraph.draw(
prefix +
str(key) +
".png",
args='-Glabel="' +
str(key) +
'" ')


def plot(graph, layout="dot"):
Expand Down Expand Up @@ -391,8 +394,10 @@ def to_df(self, write_probs=False):

for l, stages in self.stages.items():
for s in stages:
dftmp = s.to_df(labs, max_card=max_card,
write_probs=write_probs)
dftmp = s.to_df(
labs,
max_card=max_card,
write_probs=write_probs)
df = pd.concat([df, dftmp])
df.reset_index(drop=True, inplace=True)

Expand Down Expand Up @@ -874,8 +879,16 @@ def predict(self, partial_observation, return_prob=False):
def _prob_of_outcome(outcome):
nodes = (outcome[:idx] for idx in range(self.p + 1))
edges = pairwise(nodes)
probs = map(
lambda edge: self.tree[edge[0]][edge[1]]["cond_prob"], edges)

def _probs_map(edge):
try:
prob = self.tree[edge[0]][edge[1]]["cond_prob"]
except KeyError:
stage = self.get_stage(edge[0])
prob = stage.probs[edge[1][-1]]
return prob

probs = map(_probs_map, edges)
return reduce(operator.mul, probs)

if return_prob:
Expand Down Expand Up @@ -1044,10 +1057,8 @@ def sample_cstree(

for level, staging in stagings.items():
for i, stage in enumerate(staging):
if (level == -1) or (
(level > 0) and all([isinstance(i, int)
for i in stage.list_repr])
):
if (level == -1) or ((level > 0)
and all([isinstance(i, int) for i in stage.list_repr])):
stage.color = "black"
else:
stage.color = colors[i]
Expand Down
19 changes: 10 additions & 9 deletions src/cstrees/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Evaluate estimated CStrees."""
from itertools import product, pairwise, tee
from functools import reduce
import operator

from scipy.special import rel_entr

Expand All @@ -23,8 +22,16 @@ def _rel_entr_of_outcome(outcome):
edges = pairwise(nodes)

def _probs_map(edge):
est = estimated.tree[edge[0]][edge[1]]["cond_prob"]
tru = true.tree[edge[0]][edge[1]]["cond_prob"]
try:
est = estimated.tree[edge[0]][edge[1]]["cond_prob"]
except KeyError:
stage = estimated.get_stage(edge[0])
est = stage.probs[edge[1][-1]]
try:
tru = true.tree[edge[0]][edge[1]]["cond_prob"]
except KeyError:
stage = true.get_stage(edge[0])
tru = stage.probs[edge[1][-1]]
return est, tru

zipped_probs = map(_probs_map, edges)
Expand All @@ -37,9 +44,3 @@ def _probs_of_outcome(prev_pair, next_pair):
return rel_entr(est_prob_outcome, true_prob_outcome)

return sum(map(_rel_entr_of_outcome, outcomes))


# because CStrees are created on the fly while sampling, computing KL
# divergence (or making prediction) may produce key error; can catch
# these errors and set prob of corresponding outcome to 0? or sample
# more? or some other way to generate full tree?
12 changes: 12 additions & 0 deletions src/cstrees/tests/test_cstrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,15 @@ def test_predict():
assert t.predict(
partial_observation_3, True) == (
(1, 0, 0, 2), 0.5913854582044948)

# test conditional probs exist from small sample
s = ct.sample_cstree(
cards,
max_cvars=2,
prob_cvar=0.5,
prop_nonsingleton=1)
s.sample_stage_parameters(alpha=2)

s.sample(35)

s.predict({}) # shouldn't raise KeyError
14 changes: 14 additions & 0 deletions src/cstrees/tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_kl_divergence():

cards = [3, 2, 2, 3]

# KL-divergence is 0 when models are identical
t = ct.sample_cstree(
cards,
max_cvars=2,
Expand All @@ -23,6 +24,7 @@ def test_kl_divergence():

assert kl_divergence(t, t) == 0

# KL-divergence is positive nonzero whet models are different
e = ct.sample_cstree(
cards,
max_cvars=2,
Expand All @@ -32,3 +34,15 @@ def test_kl_divergence():

e.sample(1000)
assert kl_divergence(e, t) > 0

# Conditional probabilities exist even when all outcomes haven't
# been observed
s = ct.sample_cstree(
cards,
max_cvars=2,
prob_cvar=0.5,
prop_nonsingleton=1)
s.sample_stage_parameters(alpha=2)

s.sample(35)
kl_divergence(s, t) # shouldn't raise KeyError

0 comments on commit 37645ae

Please sign in to comment.