Skip to content

Commit

Permalink
add rgcn.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Jan 16, 2021
1 parent 9a13d9c commit 3ae707d
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions example_OAG/GPT_GNN/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import GCNConv, GATConv, RGCNConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, uniform
from torch_geometric.utils import softmax
Expand Down Expand Up @@ -164,12 +164,14 @@ def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads
self.base_conv = GCNConv(in_hid, out_hid)
elif self.conv_name == 'gat':
self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads)
elif self.conv_name == 'rgcn':
self.base_conv = RGCNConv(in_hid, out_hid, num_relations)
def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time):
if self.conv_name == 'hgt':
return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time)
elif self.conv_name == 'gcn':
return self.base_conv(meta_xs, edge_index)
elif self.conv_name == 'gat':
return self.base_conv(meta_xs, edge_index)


elif self.conv_name == 'rgcn':
return self.base_conv(meta_xs, edge_index, edge_type)

0 comments on commit 3ae707d

Please sign in to comment.