From 3ae707da5e22a363e55283a57e39113804913928 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Sat, 16 Jan 2021 21:58:06 +0000 Subject: [PATCH] add rgcn. --- example_OAG/GPT_GNN/conv.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/example_OAG/GPT_GNN/conv.py b/example_OAG/GPT_GNN/conv.py index e464869..0c9c14d 100644 --- a/example_OAG/GPT_GNN/conv.py +++ b/example_OAG/GPT_GNN/conv.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable -from torch_geometric.nn import GCNConv, GATConv +from torch_geometric.nn import GCNConv, GATConv, RGCNConv from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import glorot, uniform from torch_geometric.utils import softmax @@ -164,6 +164,8 @@ def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads self.base_conv = GCNConv(in_hid, out_hid) elif self.conv_name == 'gat': self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads) + elif self.conv_name == 'rgcn': + self.base_conv = RGCNConv(in_hid, out_hid, num_relations) def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time): if self.conv_name == 'hgt': return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time) @@ -171,5 +173,5 @@ def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time): return self.base_conv(meta_xs, edge_index) elif self.conv_name == 'gat': return self.base_conv(meta_xs, edge_index) - - + elif self.conv_name == 'rgcn': + return self.base_conv(meta_xs, edge_index, edge_type)