Skip to content

Commit

Permalink
Fix typo in the _categorical_exact function (#1060)
Browse files Browse the repository at this point in the history
This typo would cause the function to error if there was only a single shortest path.
  • Loading branch information
savyajha authored Aug 25, 2023
1 parent b907ed4 commit 02b8b30
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pomegranate/bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def _categorical_exact(X, sample_weight=None, include_parents=None,
structure=structure)

path = sorted(nx.all_shortest_paths(order_graph, source=(),
target=tuple(range(d)), weight="weight"))[1]
target=tuple(range(d)), weight="weight"))[0]

score, structure = 0, list( None for i in range(d) )
for u, v in zip(path[:-1], path[1:]):
Expand Down
22 changes: 11 additions & 11 deletions tests/test_bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,33 +1030,33 @@ def test_learn_structure_exact(X):
model = BayesianNetwork(algorithm='exact')
model.fit(X)

assert_tuple_equal(model._parents, ((), (0, 2), (), ()))
assert_tuple_equal(model._parents, ((), (), (0, 1), ()))

assert_array_almost_equal(model.distributions[0].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model.distributions[3].probs,
assert_array_almost_equal(model.distributions[1].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model.distributions[2].probs,
[[0.5455, 0.4545]], 4)
assert_array_almost_equal(model.distributions[1].probs[0],
assert_array_almost_equal(model.distributions[3].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model.distributions[2].probs[0],
[[[0.3333, 0.6667],
[1.0000, 0.0000]],

[[0.3333, 0.6667],
[0.3333, 0.6667]]], 4)
[[0.5, 0.5],
[0.5, 0.5]]], 4)

assert_array_almost_equal(model._factor_graph.factors[0].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model._factor_graph.factors[1].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model._factor_graph.factors[3].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model._factor_graph.factors[2].probs,
[[0.5455, 0.4545]], 4)
assert_array_almost_equal(model._factor_graph.factors[1].probs,
[[[0.0833, 0.1667],
[0.2500, 0.0000]],

[[0.0833, 0.1667],
[0.0833, 0.1667]]], 4)
[[0.1250, 0.1250],
[0.1250, 0.1250]]], 4)


def test_summarize(X, distributions):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_bayesian_network_structure_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ def test_categorical_chow_liu_raises(X, w):

def test_categorical_exact(X):
structure = _categorical_exact(X)
assert_tuple_equal(structure, ((), (0,), (), (0, 1)))
assert_tuple_equal(structure, ((), (0,), (0, 1), ()))

structure = _categorical_exact(X, max_parents=1)
assert_tuple_equal(structure, ((), (0,), (), ()))


def test_categorical_exact_weighted(X, w):
structure = _categorical_exact(X, w)
assert_tuple_equal(structure, ((), (0,), (), (0, 1)))
assert_tuple_equal(structure, ((), (0,), (0, 1), ()))

structure = _categorical_exact(X, w, max_parents=1)
assert_tuple_equal(structure, ((), (0,), (), ()))
Expand All @@ -169,15 +169,15 @@ def test_categorical_exact_weighted(X, w):
def test_categorical_exact_exclude_parents(X):
exclude_parents = ((), (2,), (), (1,))
structure = _categorical_exact(X, exclude_parents=exclude_parents)
assert_tuple_equal(structure, ((), (), (0, 3), (0,)))
assert_tuple_equal(structure, ((), (), (0,), (0, 2)))

structure = _categorical_exact(X, exclude_parents=exclude_parents,
max_parents=1)
assert_tuple_equal(structure, ((), (), (0,), ()))
assert_tuple_equal(structure, ((), (0,), (), ()))

exclude_parents = ((), (2,), (), (0, 1))
structure = _categorical_exact(X, exclude_parents=exclude_parents)
assert_tuple_equal(structure, ((3,), (), (0,3), ()))
assert_tuple_equal(structure, ((2, 3), (), (), (2,)))


def test_categorical_exact_large():
Expand Down

0 comments on commit 02b8b30

Please sign in to comment.