Skip to content

Commit

Permalink
fix bug for graph classification explanation
Browse files Browse the repository at this point in the history
  • Loading branch information
anpolol committed Dec 12, 2023
1 parent 6a74012 commit f4b13ad
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions stable_gnn/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,10 @@ def _variable_selection(
p_values.append(p)

number_candidates = int(top_node * 4)
print('number candidates', number_candidates)
print('pvalues', p_values)
candidate_nodes = np.argpartition(p_values, number_candidates)[0:number_candidates]

print('candidate_nodes', candidate_nodes)
# Round 2
data, neighbors = self._data_generation(
target=None,
Expand All @@ -259,12 +261,16 @@ def _variable_selection(
dependent_nodes.append(node)

top_p = np.min((top_node, len(self.features) - 1))
print('top_p ', top_p)
print('pvalues', p_values)
ind_top_p = np.argpartition(p_values, top_p)[0:top_p]
pgm_nodes = list(ind_top_p)

data = data.rename(columns={"A": 0, "B": 1})
data = data.rename(columns=ind_sub_to_ori)


print('ind_top_p', ind_top_p)
return pgm_nodes, data, candidate_nodes

else:
Expand Down Expand Up @@ -321,8 +327,9 @@ def structure_learning(
:param child: (bool, Optional): If False or None, no-child constraint is applied (default: None)
:return: (BayesianNetwork): Pgm explanation in Bayesian Net form
"""

print('before var selection')
subnodes, data, pgm_stats = self._variable_selection(target, top_node, num_samples, pred_threshold)
print('after var selection',subnodes)

# единственное место, где кастуем к строкам!
data.columns = data.columns.astype(str)
Expand All @@ -332,8 +339,8 @@ def structure_learning(
subnodes = [str(x) for x in subnodes]
subnodes_no_target = [str(node) for node in subnodes if node != target]

mk_blanket = self._search_m_k(data, target, subnodes_no_target.copy())
if child is None:
mk_blanket = self._search_m_k(data, target, subnodes_no_target.copy())
est = HillClimbSearch(data[subnodes_no_target])
pgm_no_target = est.estimate(scoring_method=BicScore(data))
print("estimation", pgm_no_target.nodes(), pgm_no_target.edges())
Expand Down

0 comments on commit f4b13ad

Please sign in to comment.