Commit

fix conflict
anpolol committed Jan 8, 2024
1 parent a77721a commit ae5b336
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions stable_gnn/model_gc.py
@@ -143,6 +143,89 @@ def convert_dataset(

    return train_dataset, test_dataset, val_dataset, int(n_min)

@staticmethod
def _func(x: List[str]) -> Optional[int]:
    # Map a BN edge between an "eigen<i>" node and the target "y" to the
    # eigenvalue index i; edges not touching "y" yield None.
    if x[1] == "y" and len(x[0]) > 1:
        return int(x[0][5:])  # strip the five-character "eigen" prefix
    elif x[0] == "y" and len(x[1]) > 1:
        return int(x[1][5:])
    return None
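
A quick sanity check of the mapping above (an illustrative sketch, not part of this commit; `ModelGC` stands in for the enclosing class, whose name is not shown in the diff):

    assert ModelGC._func(["eigen3", "y"]) == 3
    assert ModelGC._func(["y", "eigen12"]) == 12
    assert ModelGC._func(["eigen1", "eigen2"]) is None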

def _data_eigen_exctractor(self, dataset: List[Graph]) -> pd.DataFrame:
    # Collect the n_min largest adjacency eigenvalues of each graph, plus its
    # label "y", into a DataFrame for the Bayesian network.
    columns_list = list(map(lambda x: "eigen" + str(x), range(self.n_min)))
    data_bamt = pd.DataFrame(columns=columns_list + ["y"])
    for gr in dataset:
        adj_matrix = to_dense_adj(gr.edge_index)
        # torch.eig was removed from recent PyTorch; torch.linalg.eigvals is
        # its replacement (take the real parts of the spectrum).
        eig = torch.linalg.eigvals(adj_matrix.reshape(adj_matrix.shape[1], adj_matrix.shape[2])).real
        ordered, indices = torch.sort(eig[: gr.num_nodes], descending=True)
        to_append = pd.Series(ordered[: self.n_min].tolist() + gr.y.tolist(), index=data_bamt.columns)
        # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
        data_bamt = pd.concat([data_bamt, to_append.to_frame().T], ignore_index=True)

    return data_bamt
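
A minimal, self-contained sketch of the eigen-feature step (assuming current PyTorch, where torch.linalg.eigvals replaces the removed torch.eig):

    import torch
    from torch_geometric.utils import to_dense_adj

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # a 3-node path graph
    adj = to_dense_adj(edge_index)[0]                         # (3, 3) dense adjacency
    eig = torch.linalg.eigvals(adj).real                      # real parts of the spectrum
    ordered, _ = torch.sort(eig, descending=True)
    print(ordered)  # approximately [1.4142, 0.0000, -1.4142]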

def _bayesian_network_build(self, data_bamt: pd.DataFrame) -> Nets.HybridBN:
    # Search for the network weights with bamt.
    for col in data_bamt.columns[:-1]:  # the "eigen<i>" feature columns
        data_bamt[col] = data_bamt[col].astype(float)
    data_bamt["y"] = data_bamt["y"].astype(int)

    bn = Nets.HybridBN(has_logit=True)
    discretizer = preprocessing.KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="quantile")
    p = Preprocessor([("discretizer", discretizer)])
    discretized_data, est = p.apply(data_bamt)

    bn.add_nodes(p.info)

    params = dict()
    params["remove_init_edges"] = self.remove_init_edges

    if self.init_edges:
        params["init_edges"] = [("eigen" + str(i), "y") for i in range(self.n_min)] + [  # type: ignore
            ("y", "eigen" + str(i)) for i in range(self.n_min)
        ]

    if self.white_list:
        params["white_list"] = [("eigen" + str(i), "y") for i in range(self.n_min)] + [  # type: ignore
            ("y", "eigen" + str(i)) for i in range(self.n_min)
        ]

    bn.add_edges(
        discretized_data,
        scoring_function=(self.score_func, self.score),
        params=params,
    )

    bn.calculate_weights(discretized_data)
    bn.plot("BN1.html")
    return bn
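
For concreteness, with n_min = 2 the init_edges / white_list lists built above expand to a star around the class node "y" (a sketch, not part of the commit):

    n_min = 2
    edges = [("eigen" + str(i), "y") for i in range(n_min)] + [
        ("y", "eigen" + str(i)) for i in range(n_min)
    ]
    # [('eigen0', 'y'), ('eigen1', 'y'), ('y', 'eigen0'), ('y', 'eigen1')]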

@staticmethod
def _convolve(dataset: List[Graph], weights: List[float], left_vertices: List[int]) -> List[Graph]:
    # Spectrally filter each graph: keep only the selected eigenvalues,
    # rescale them by the learned weights, and rebuild the adjacency.
    new_data = []
    for graph in dataset:
        adj = to_dense_adj(graph.edge_index)
        # torch.eig(..., eigenvectors=True) was removed from recent PyTorch;
        # torch.linalg.eig returns (eigenvalues, eigenvectors) instead.
        eig_complex, eigenvectors = torch.linalg.eig(adj.reshape(adj.shape[1], adj.shape[2]))
        eig = eig_complex.real
        eigenvectors = eigenvectors.real
        ordered, indices = torch.sort(eig[: graph.num_nodes], descending=True)
        lef = indices[left_vertices]
        zeroed = torch.tensor(list(set(range(len(eig))) - set(lef.tolist())))
        if len(zeroed) > 0:
            eig[zeroed] = 0  # drop the discarded part of the spectrum

        for e, d in enumerate(lef):
            eig[d] = eig[d] * weights[e]  # reweight the kept eigenvalues

        eigenvalues = torch.diag(eig)
        convolved = torch.matmul(torch.matmul(eigenvectors, eigenvalues), eigenvectors.T)

        graph.edge_index, graph.edge_weight = dense_to_sparse(convolved)
        graph.edge_index = graph.edge_index.type(torch.LongTensor)
        new_data.append(graph)
    return new_data
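
The reconstruction in _convolve is a standard spectral filter, A' = V diag(w * lambda_kept) V^T. A self-contained sketch on a toy symmetric adjacency (torch.linalg.eigh here, with hypothetical weights standing in for the BN-derived ones):

    import torch
    from torch_geometric.utils import dense_to_sparse

    adj = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
    vals, vecs = torch.linalg.eigh(adj)              # real spectrum of a symmetric adjacency
    keep = torch.argsort(vals, descending=True)[:2]  # keep the two largest eigenvalues
    mask = torch.zeros_like(vals)
    mask[keep] = torch.tensor([1.0, 0.5])            # hypothetical learned weights
    filtered = vecs @ torch.diag(vals * mask) @ vecs.T
    filtered[filtered.abs() < 1e-6] = 0.0            # suppress numerical noise
    edge_index, edge_weight = dense_to_sparse(filtered)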

@staticmethod
def self_supervised_loss(deg_pred: Tensor, batch: Batch) -> Tensor:
"""
Expand Down
