-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathMPNN_readout.py
92 lines (84 loc) · 3.24 KB
/
MPNN_readout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# MPNN
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from dgl.nn.pytorch import Set2Set
from dgllife.model.gnn import MPNNGNN
# Only the MPNN readout model is part of this module's public API
# (this is what ``from <module> import *`` exposes).
__all__ = ["MPNN_readout"]
# pylint: disable=W0221
class MPNN_readout(nn.Module):
    """MPNN for regression and classification on graphs.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    The model encodes nodes with ``MPNNGNN``, pools the node representations
    into one vector per graph with ``Set2Set``, optionally appends extra
    per-graph descriptor features, and maps the result to ``n_tasks`` outputs
    with a small MLP head (dropout -> linear -> ReLU -> batch norm -> linear).

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    num_step_set2set : int
        Number of set2set steps. Default to 6.
    dropout : float
        Dropout probability applied to the pooled (and, if given,
        descriptor-concatenated) graph features before the prediction head.
        Default to 0.
    num_layer_set2set : int
        Number of set2set layers. Default to 3.
    descriptor_feats : int
        Width of the optional extra per-graph features passed to ``forward``
        as ``concat_feats``. The first prediction layer expects
        ``2 * node_out_feats + descriptor_feats`` inputs, so this must be 0
        when ``concat_feats`` is not used. Default to 0.
    """

    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=64,
                 edge_hidden_feats=128,
                 n_tasks=1,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 dropout=0,
                 num_layer_set2set=3, descriptor_feats=0):
        super().__init__()
        # Node encoder: edge-conditioned message passing over the graph.
        self.gnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)
        # Graph-level pooling of the node representations.
        self.readout = Set2Set(input_dim=node_out_feats,
                               n_iters=num_step_set2set,
                               n_layers=num_layer_set2set)
        # Prediction head. The layer order is kept fixed so the state_dict
        # keys (predict.0 ... predict.4) stay compatible with saved
        # checkpoints.
        # NOTE: BatchNorm1d in training mode needs more than one graph per
        # batch.
        self.predict = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(2 * node_out_feats + descriptor_feats, node_out_feats),
            nn.ReLU(),
            nn.BatchNorm1d(node_out_feats),
            nn.Linear(node_out_feats, n_tasks)
        )

    def forward(self, g, node_feats, edge_feats, concat_feats=None):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features.
        concat_feats : float32 tensor of shape (G, descriptor_feats), optional
            Extra per-graph features appended (along dim 1) to the pooled
            graph representation. Pass it exactly when the model was built
            with ``descriptor_feats > 0``. Default to None.

        Returns
        -------
        float32 tensor of shape (G, n_tasks)
            Prediction for the graphs in the batch. G for the number of graphs.
        """
        node_feats = self.gnn(g, node_feats, edge_feats)
        graph_feats = self.readout(g, node_feats)
        # ``!= None`` would route through the tensor's overloaded comparison
        # operator; ``is not None`` is the explicit identity check (PEP 8).
        if concat_feats is not None:
            graph_feats = torch.cat((graph_feats, concat_feats), dim=1)
        return self.predict(graph_feats)