-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgcs_convolution.py
181 lines (149 loc) · 5.1 KB
/
gcs_convolution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import logging
from tensorflow.keras import backend as K # noqa
from gns.config.settings import settings_fabric
from gns.layer.convolution import ConvolutionalGeneralLayer
from gns.utils.normalized_adjacency_matrix import normalized_adjacency_matrix
from gns.utils.dot_production_modal import dot_production_modal
settings = settings_fabric()
logger = logging.getLogger(__name__)
class GCSConvolutionalGeneralLayer(ConvolutionalGeneralLayer):
    """A `GraphConv`-style layer augmented with a trainable skip connection.

    Computes ``A @ (X W1) + X W2`` (plus an optional bias), followed by the
    activation function, where ``A`` is the normalized adjacency matrix and
    ``X`` the node features.

    Supported data models: single, disjoint, mixed, batch.

    Inputs:
        - node features of shape ``([batch], n_nodes, n_node_features)``;
        - normalized adjacency matrix of shape ``([batch], n_nodes, n_nodes)``
          (can be computed with ``gns.utils.normalized_adjacency_matrix()``).

    Output:
        Node features shaped like the input, except the last dimension is
        replaced by ``channels``.
    """

    def __init__(
        self,
        channels,
        activation=None,
        use_bias=True,
        kernel_initializer=settings.initializers.glorot_uniform,
        bias_initializer=settings.initializers.zeros,
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        """
        Args:
            `channels`: number of output channels;
            `activation`: activation function;
            `use_bias`: bool, add a bias vector to the output;
            `kernel_initializer`: initializer for the weights;
            `bias_initializer`: initializer for the bias vector;
            `kernel_regularizer`: regularization applied to the weights;
            `bias_regularizer`: regularization applied to the bias vector;
            `activity_regularizer`: regularization applied to the output;
            `kernel_constraint`: constraint applied to the weights;
            `bias_constraint`: constraint applied to the bias vector.
        """
        super().__init__(
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
        self.channels = channels

    def build(self, input_shape):
        """Create the two kernels (main and skip) and the optional bias.

        Args:
            input_shape: list/tuple of input shapes; the first entry is the
                node-feature shape, whose last dimension gives the input size.
        """
        assert len(input_shape) >= 2
        # Last dimension of the node-feature input determines the kernel rows.
        input_dim = input_shape[0][-1]

        logger.info("Create the first kernel.")
        # Main kernel: transforms features before aggregation over A.
        self.kernel_1 = self.add_weight(
            shape=(input_dim, self.channels),
            initializer=self.kernel_initializer,
            name=settings.names.kernel_1,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        logger.info("Create the second kernel.")
        # Skip kernel: transforms features for the skip connection.
        self.kernel_2 = self.add_weight(
            shape=(input_dim, self.channels),
            initializer=self.kernel_initializer,
            name=settings.names.kernel_2,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.channels,),
                initializer=self.bias_initializer,
                name=settings.names.bias,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )

        self.built = True

    def call(self, inputs, mask=None):
        """Run the forward pass: aggregation plus trainable skip connection.

        Args:
            inputs: pair ``(x, a)`` of node features and normalized adjacency.
            mask: optional mask; only the node mask (``mask[0]``) is applied.

        Returns:
            Activated node features with ``channels`` output features.
        """
        features, adjacency = inputs

        # Aggregate transformed features over the (normalized) adjacency.
        aggregated = dot_production_modal(adjacency, K.dot(features, self.kernel_1))
        # Add the trainable skip connection.
        result = aggregated + K.dot(features, self.kernel_2)

        if self.use_bias:
            result = K.bias_add(result, self.bias)
        if mask is not None:
            result *= mask[0]

        return self.activation(result)

    @property
    def config(self):
        # Only `channels` is layer-specific; the rest lives in the base class.
        return {"channels": self.channels}

    @staticmethod
    def preprocess(a):
        """Normalize a raw adjacency matrix for use with this layer."""
        return normalized_adjacency_matrix(a)
def gsn_convolutional_general_layer_fabric(
    channels,
    activation=None,
    use_bias=True,
    kernel_initializer=settings.initializers.glorot_uniform,
    bias_initializer=settings.initializers.zeros,
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    **kwargs
):
    """Factory helper: construct a `GCSConvolutionalGeneralLayer`.

    All arguments are forwarded unchanged to the layer constructor; see
    `GCSConvolutionalGeneralLayer.__init__` for their meaning.
    """
    # Forward everything by keyword so the mapping to the constructor is explicit.
    return GCSConvolutionalGeneralLayer(
        channels,
        activation=activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        **kwargs
    )