From f4b13addfec678376e779fa8db325ec428fd7082 Mon Sep 17 00:00:00 2001 From: anpolol Date: Tue, 12 Dec 2023 17:53:08 +0300 Subject: [PATCH] fix bug for graph classification explanation --- stable_gnn/explain.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/stable_gnn/explain.py b/stable_gnn/explain.py index 012d245..8918e49 100644 --- a/stable_gnn/explain.py +++ b/stable_gnn/explain.py @@ -232,8 +232,10 @@ def _variable_selection( p_values.append(p) number_candidates = int(top_node * 4) + print('number candidates', number_candidates) + print('pvalues', p_values) candidate_nodes = np.argpartition(p_values, number_candidates)[0:number_candidates] - + print('candidate_nodes', candidate_nodes) # Round 2 data, neighbors = self._data_generation( target=None, @@ -259,12 +261,16 @@ def _variable_selection( dependent_nodes.append(node) top_p = np.min((top_node, len(self.features) - 1)) + print('top_p ', top_p) + print('pvalues', p_values) ind_top_p = np.argpartition(p_values, top_p)[0:top_p] pgm_nodes = list(ind_top_p) data = data.rename(columns={"A": 0, "B": 1}) data = data.rename(columns=ind_sub_to_ori) + + print('ind_top_p', ind_top_p) return pgm_nodes, data, candidate_nodes else: @@ -321,8 +327,9 @@ def structure_learning( :param child: (bool, Optional): If False or None, no-child constraint is applied (default: None) :return: (BayesianNetwork): Pgm explanation in Bayesian Net form """ - + print('before var selection') subnodes, data, pgm_stats = self._variable_selection(target, top_node, num_samples, pred_threshold) + print('after var selection',subnodes) # единственное место, где кастуем к строкам! data.columns = data.columns.astype(str) @@ -332,8 +339,8 @@ def structure_learning( subnodes = [str(x) for x in subnodes] subnodes_no_target = [str(node) for node in subnodes if node != target] - mk_blanket = self._search_m_k(data, target, subnodes_no_target.copy()) if child is None: + mk_blanket = self._search_m_k(data, target, subnodes_no_target.copy()) est = HillClimbSearch(data[subnodes_no_target]) pgm_no_target = est.estimate(scoring_method=BicScore(data)) print("estimation", pgm_no_target.nodes(), pgm_no_target.edges())