# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class FSGNN(nn.Module):
    """FSGNN: applies a separate linear layer to each precomputed feature
    matrix (one per hop) and combines them with learned attention weights."""

    def __init__(self, nfeat, nlayers, nhidden, nclass, dropout):
        super(FSGNN, self).__init__()
        # One linear transform per input feature matrix (one per hop).
        self.fc1 = nn.ModuleList([nn.Linear(nfeat, nhidden) for _ in range(nlayers)])
        # Final classifier over the concatenated per-hop representations.
        self.fc2 = nn.Linear(nhidden * nlayers, nclass)
        self.dropout = dropout
        self.act_fn = nn.ReLU()
        # Learnable attention scores, one per hop, normalized by softmax.
        self.att = nn.Parameter(torch.ones(nlayers))
        self.sm = nn.Softmax(dim=0)

    def forward(self, list_mat, layer_norm):
        mask = self.sm(self.att)
        list_out = []
        for ind, mat in enumerate(list_mat):
            tmp_out = self.fc1[ind](mat)
            if layer_norm:
                # L2-normalize each node's representation.
                tmp_out = F.normalize(tmp_out, p=2, dim=1)
            # Scale this hop's representation by its attention weight.
            tmp_out = mask[ind] * tmp_out
            list_out.append(tmp_out)
        final_mat = torch.cat(list_out, dim=1)
        out = self.act_fn(final_mat)
        out = F.dropout(out, self.dropout, training=self.training)
        out = self.fc2(out)
        return F.log_softmax(out, dim=1)
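

# Hedged usage sketch (not part of the original file). FSGNN expects
# `list_mat` to hold one precomputed feature matrix per hop (e.g. X, AX,
# A^2X, ...); the shapes and hyperparameters below are illustrative
# assumptions, not values taken from this repository.
def _demo_fsgnn():
    torch.manual_seed(0)
    nlayers, nfeat, nhidden, nclass = 3, 16, 8, 4
    model = FSGNN(nfeat=nfeat, nlayers=nlayers, nhidden=nhidden,
                  nclass=nclass, dropout=0.5)
    model.eval()  # disable dropout for a deterministic sanity check
    list_mat = [torch.randn(5, nfeat) for _ in range(nlayers)]  # 5 nodes, 3 hops
    return model(list_mat, layer_norm=True)  # log-probabilities, shape (5, 4)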


class FSGNN_Large(nn.Module):
    """Variant of FSGNN for large graphs: processes a mini-batch slice of
    rows at a time and adds a second hidden layer before classification."""

    def __init__(self, nfeat, nlayers, nhidden, nclass, dp1, dp2):
        super(FSGNN_Large, self).__init__()
        self.wt1 = nn.ModuleList([nn.Linear(nfeat, nhidden) for _ in range(nlayers)])
        self.fc2 = nn.Linear(nhidden * nlayers, nhidden)
        self.fc3 = nn.Linear(nhidden, nclass)
        self.dropout1 = dp1
        self.dropout2 = dp2
        self.act_fn = nn.ReLU()
        self.att = nn.Parameter(torch.ones(nlayers))
        self.sm = nn.Softmax(dim=0)

    def forward(self, list_adj, layer_norm, st=0, end=0):
        # Softmax attention over hops, rescaled by the number of matrices.
        mask = self.sm(self.att)
        mask = len(list_adj) * mask
        list_out = []
        for ind, mat in enumerate(list_adj):
            # Slice the mini-batch rows [st:end] and move them to the GPU.
            mat = mat[st:end, :].cuda()
            tmp_out = self.wt1[ind](mat)
            if layer_norm:
                tmp_out = F.normalize(tmp_out, p=2, dim=1)
            tmp_out = mask[ind] * tmp_out
            list_out.append(tmp_out)
        final_mat = torch.cat(list_out, dim=1)
        out = self.act_fn(final_mat)
        out = F.dropout(out, self.dropout1, training=self.training)
        out = self.fc2(out)
        out = self.act_fn(out)
        out = F.dropout(out, self.dropout2, training=self.training)
        out = self.fc3(out)
        return F.log_softmax(out, dim=1)
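

# Hedged usage sketch (not part of the original file). FSGNN_Large moves
# each row slice onto the GPU inside forward(), so this demo requires CUDA;
# all sizes below are illustrative assumptions.
def _demo_fsgnn_large():
    torch.manual_seed(0)
    nlayers, nfeat, nhidden, nclass = 3, 16, 8, 4
    model = FSGNN_Large(nfeat=nfeat, nlayers=nlayers, nhidden=nhidden,
                        nclass=nclass, dp1=0.5, dp2=0.5).cuda()
    model.eval()
    list_adj = [torch.randn(100, nfeat) for _ in range(nlayers)]
    # Process the first mini-batch of 32 rows.
    return model(list_adj, layer_norm=True, st=0, end=32)  # shape (32, 4)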


if __name__ == '__main__':
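    # Minimal smoke tests for the sketches above (illustrative only).
    print(_demo_fsgnn().shape)            # expected: torch.Size([5, 4])
    if torch.cuda.is_available():
        print(_demo_fsgnn_large().shape)  # expected: torch.Size([32, 4])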