# gin.py
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.nn import GINConv, global_add_pool
class Encoder(torch.nn.Module):
    """GIN graph encoder: stacks `num_gc_layers` GINConv layers and
    concatenates the add-pooled output of every layer into a single
    graph-level embedding."""

    def __init__(self, num_features, dim, num_gc_layers, device):
        super(Encoder, self).__init__()
        self.num_gc_layers = num_gc_layers
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        self.device = device
        for i in range(num_gc_layers):
            # The first layer maps the raw features to `dim`; every later
            # layer maps `dim` to `dim`.
            if i:
                nn = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
            else:
                nn = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
            self.convs.append(GINConv(nn))
            self.bns.append(torch.nn.BatchNorm1d(dim))

    def forward(self, x, edge_index, batch):
        # Featureless graphs: fall back to a constant one-feature per node.
        if x is None:
            x = torch.ones((batch.shape[0], 1)).to(self.device)
        xs = []
        for i in range(self.num_gc_layers):
            x = F.relu(self.convs[i](x, edge_index))
            x = self.bns[i](x)
            xs.append(x)
        # Add-pool each layer's node features to the graph level, then
        # concatenate across layers.
        xpool = [global_add_pool(h, batch) for h in xs]
        x = torch.cat(xpool, 1)
        # Returns (graph embeddings, per-node embeddings across layers).
        return x, torch.cat(xs, 1)

    def get_embeddings(self, loader, device):
        """Embed every graph in `loader`; returns (embeddings, labels)
        as numpy arrays."""
        ret = []
        y = []
        with torch.no_grad():
            for data in loader:
                # The loader is assumed to yield each batch wrapped in a
                # list/tuple, hence the unwrapping below.
                data = data[0]
                data.to(device)
                x, edge_index, batch = data.x, data.edge_index, data.batch
                if x is None:
                    x = torch.ones((batch.shape[0], 1)).to(device)
                x, _ = self.forward(x, edge_index, batch)
                ret.append(x.cpu().numpy())
                y.append(data.y.cpu().numpy())
        ret = np.concatenate(ret, 0)
        y = np.concatenate(y, 0)
        return ret, y

    def get_embeddings_v(self, loader):
        """Visualization helper: embeds the first two batches only and
        returns (graph embeddings, node embeddings, edge_index) of the
        last batch processed."""
        ret = []
        y = []
        with torch.no_grad():
            for n, data in enumerate(loader):
                data.to(self.device)
                x, edge_index, batch = data.x, data.edge_index, data.batch
                if x is None:
                    x = torch.ones((batch.shape[0], 1)).to(self.device)
                x_g, x = self.forward(x, edge_index, batch)
                x_g = x_g.cpu().numpy()
                ret = x.cpu().numpy()
                # Note: despite its name, `y` holds the edge index here,
                # not the labels.
                y = data.edge_index.cpu().numpy()
                print(data.y)
                if n == 1:
                    break
        return x_g, ret, y
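

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): embeds a
# TUDataset with the Encoder above and scores the embeddings with a
# logistic-regression probe. The dataset name ("MUTAG"), data path, batch
# size, and hyper-parameters are assumptions chosen for the example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = TUDataset('data', name='MUTAG')  # assumed path and dataset
    loader = DataLoader(dataset, batch_size=128, shuffle=False)

    encoder = Encoder(max(dataset.num_features, 1), dim=32,
                      num_gc_layers=3, device=device).to(device)
    encoder.eval()

    # Embed each batch by calling the encoder directly; this sidesteps the
    # list-wrapped loader that get_embeddings() assumes.
    embs, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            g, _ = encoder(batch.x, batch.edge_index, batch.batch)
            embs.append(g.cpu().numpy())
            labels.append(batch.y.cpu().numpy())
    embs = np.concatenate(embs, 0)
    labels = np.concatenate(labels, 0)

    # Linear probe on the (untrained) encoder's embeddings, 10-fold CV.
    scores = cross_val_score(LogisticRegression(max_iter=1000), embs, labels, cv=10)
    print('accuracy: %.3f +/- %.3f' % (scores.mean(), scores.std()))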