# layers.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class HyperGraphAttentionLayerSparse(nn.Module):
    """Hypergraph attention layer: attends nodes -> hyperedges, then hyperedges -> nodes."""

    def __init__(self, in_features, out_features, dropout, alpha, transfer, concat=True, bias=False):
        super().__init__()
        self.dropout = dropout          # dropout rate for attention scores and edge features
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha              # negative slope of the LeakyReLU on attention logits
        self.concat = concat            # if True, apply a final ELU non-linearity
        self.transfer = transfer        # if True, also project the input features with self.weight

        if self.transfer:
            self.weight = Parameter(torch.Tensor(self.in_features, self.out_features))
        else:
            self.register_parameter('weight', None)

        # weight2 projects inputs for attention; weight3 projects aggregated edge features.
        self.weight2 = Parameter(torch.Tensor(self.in_features, self.out_features))
        self.weight3 = Parameter(torch.Tensor(self.out_features, self.out_features))

        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features))
        else:
            self.register_parameter('bias', None)

        # Trainable context vector acting as the shared attention query when
        # aggregating nodes into hyperedges.
        self.word_context = nn.Embedding(1, self.out_features)

        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))   # node -> edge attention vector
        self.a2 = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))  # edge -> node attention vector
        self.leakyrelu = nn.LeakyReLU(self.alpha)

        self.reset_parameters()

    def reset_parameters(self):
        # Uniform initialization in [-1/sqrt(out_features), 1/sqrt(out_features)].
        stdv = 1. / math.sqrt(self.out_features)
        if self.weight is not None:
            self.weight.data.uniform_(-stdv, stdv)
        self.weight2.data.uniform_(-stdv, stdv)
        self.weight3.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
        nn.init.uniform_(self.a.data, -stdv, stdv)
        nn.init.uniform_(self.a2.data, -stdv, stdv)
        nn.init.uniform_(self.word_context.weight.data, -stdv, stdv)

    def forward(self, x, adj):
        # x:   (batch, n_nodes, in_features) node features
        # adj: (batch, n_edges, n_nodes) hyperedge incidence matrix (nonzero => node in edge)
        x_4att = x.matmul(self.weight2)

        if self.transfer:
            x = x.matmul(self.weight)
            if self.bias is not None:
                x = x + self.bias

        N1 = adj.shape[1]  # number of hyperedges
        N2 = adj.shape[2]  # number of nodes

        pair = adj.nonzero().t()  # (3, nnz) indices of (batch, edge, node) incidences

        # Stage 1: node -> hyperedge attention. For each incidence, pair the
        # shared context vector with the node's projected features.
        get = lambda i: x_4att[i][adj[i].nonzero().t()[1]]
        x1 = torch.cat([get(i) for i in torch.arange(x.shape[0]).long()])

        q1 = self.word_context.weight.view(1, -1).repeat(x1.shape[0], 1)

        pair_h = torch.cat((q1, x1), dim=-1)
        pair_e = self.leakyrelu(torch.matmul(pair_h, self.a).squeeze())
        assert not torch.isnan(pair_e).any()
        pair_e = F.dropout(pair_e, self.dropout, training=self.training)

        # Scatter the per-incidence scores back into a dense (batch, n_edges, n_nodes) tensor.
        e = torch.sparse_coo_tensor(pair, pair_e, torch.Size([x.shape[0], N1, N2])).to_dense()

        zero_vec = -9e15 * torch.ones_like(e)          # mask out non-incident pairs
        attention = torch.where(adj > 0, e, zero_vec)
        attention_edge = F.softmax(attention, dim=2)   # normalize over the nodes in each edge

        edge = torch.matmul(attention_edge, x)         # aggregated hyperedge features
        edge = F.dropout(edge, self.dropout, training=self.training)

        # Stage 2: hyperedge -> node attention, with the node's projected
        # features as query and the aggregated edge features as keys.
        edge_4att = edge.matmul(self.weight3)

        get = lambda i: edge_4att[i][adj[i].nonzero().t()[0]]
        y1 = torch.cat([get(i) for i in torch.arange(x.shape[0]).long()])

        get = lambda i: x_4att[i][adj[i].nonzero().t()[1]]
        q1 = torch.cat([get(i) for i in torch.arange(x.shape[0]).long()])

        pair_h = torch.cat((q1, y1), dim=-1)
        pair_e = self.leakyrelu(torch.matmul(pair_h, self.a2).squeeze())
        assert not torch.isnan(pair_e).any()
        pair_e = F.dropout(pair_e, self.dropout, training=self.training)

        e = torch.sparse_coo_tensor(pair, pair_e, torch.Size([x.shape[0], N1, N2])).to_dense()

        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention_node = F.softmax(attention.transpose(1, 2), dim=2)  # normalize over each node's edges

        node = torch.matmul(attention_node, edge)      # updated node features

        if self.concat:
            node = F.elu(node)

        return node

    def __repr__(self):
        return f'{self.__class__.__name__} ({self.in_features} -> {self.out_features})'
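

# A minimal smoke test, not part of the original file: a sketch of how the layer
# might be exercised. The shapes are assumptions read off forward() above:
# `x` as (batch, n_nodes, in_features) and `adj` as a dense
# (batch, n_edges, n_nodes) incidence matrix; the sizes and hyperparameters
# below are arbitrary.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, n_edges, n_nodes, in_f, out_f = 2, 4, 6, 8, 16
    layer = HyperGraphAttentionLayerSparse(in_f, out_f, dropout=0.3, alpha=0.2, transfer=True)

    x = torch.randn(batch, n_nodes, in_f)
    # Random 0/1 incidence: adj[b, e, v] = 1 when node v belongs to hyperedge e.
    adj = (torch.rand(batch, n_edges, n_nodes) > 0.5).float()
    adj[:, :, 0] = 1.0  # ensure every hyperedge is non-empty

    out = layer(x, adj)
    print(layer)       # HyperGraphAttentionLayerSparse (8 -> 16)
    print(out.shape)   # torch.Size([2, 6, 16])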