Added AP-GCN #104

Open · wants to merge 5 commits into base: master
2 changes: 1 addition & 1 deletion cogdl/models/README.md
@@ -5,7 +5,7 @@ CogDL now supports the following models for different tasks:

- unsupervised node classification: ProNE [(Zhang et al, IJCAI'19)](https://www.ijcai.org/Proceedings/2019/0594.pdf), NetMF [(Qiu et al, WSDM'18)](http://arxiv.org/abs/1710.02971), Node2vec [(Grover et al, KDD'16)](http://dl.acm.org/citation.cfm?doid=2939672.2939754), NetSMF [(Qiu et at, WWW'19)](https://arxiv.org/abs/1906.11156), DeepWalk [(Perozzi et al, KDD'14)](http://arxiv.org/abs/1403.6652), LINE [(Tang et al, WWW'15)](http://arxiv.org/abs/1503.03578), Hope [(Ou et al, KDD'16)](http://dl.acm.org/citation.cfm?doid=2939672.2939751), SDNE [(Wang et al, KDD'16)](https://www.kdd.org/kdd2016/papers/files/rfp0191-wangAemb.pdf), GraRep [(Cao et al, CIKM'15)](http://dl.acm.org/citation.cfm?doid=2806416.2806512), DNGR [(Cao et al, AAAI'16)](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/download/12423/11715).

- - semi-supervised node classification: SGC-PN [(Zhao & Akoglu, 2019)](https://arxiv.org/abs/1909.12223), Graph U-Net [(Gao et al., 2019)](https://arxiv.org/abs/1905.05178), MixHop [(Abu-El-Haija et al., ICML'19)](https://arxiv.org/abs/1905.00067), DR-GAT [(Zou et al., 2019)](https://arxiv.org/abs/1907.02237), GAT [(Veličković et al., ICLR'18)](https://arxiv.org/abs/1710.10903), DGI [(Veličković et al., ICLR'19)](https://arxiv.org/abs/1809.10341), GCN [(Kipf et al., ICLR'17)](https://arxiv.org/abs/1609.02907), GraphSAGE [(Hamilton et al., NeurIPS'17)](https://arxiv.org/abs/1706.02216), Chebyshev [(Defferrard et al., NeurIPS'16)](https://arxiv.org/abs/1606.09375).
+ - semi-supervised node classification: AP-GCN [(Spinelli et al., TNNLS'20)](https://arxiv.org/abs/2002.10306), SGC-PN [(Zhao & Akoglu, 2019)](https://arxiv.org/abs/1909.12223), Graph U-Net [(Gao et al., 2019)](https://arxiv.org/abs/1905.05178), MixHop [(Abu-El-Haija et al., ICML'19)](https://arxiv.org/abs/1905.00067), DR-GAT [(Zou et al., 2019)](https://arxiv.org/abs/1907.02237), GAT [(Veličković et al., ICLR'18)](https://arxiv.org/abs/1710.10903), DGI [(Veličković et al., ICLR'19)](https://arxiv.org/abs/1809.10341), GCN [(Kipf et al., ICLR'17)](https://arxiv.org/abs/1609.02907), GraphSAGE [(Hamilton et al., NeurIPS'17)](https://arxiv.org/abs/1706.02216), Chebyshev [(Defferrard et al., NeurIPS'16)](https://arxiv.org/abs/1606.09375).

- heterogeneous node classification: GTN [(Yun et al, NeurIPS'19)](https://arxiv.org/abs/1911.06455), HAN [(Xiao et al, WWW'19)](https://arxiv.org/abs/1903.07293), PTE [(Tang et al, KDD'15)](https://arxiv.org/abs/1508.00200), Metapath2vec [(Dong et al, KDD'17)](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf), Hin2vec [(Fu et al, CIKM'17)](https://dl.acm.org/doi/10.1145/3132847.3132953).

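In summary, the adaptive propagation added in the diff below works as follows (a sketch based on the code and the paper): each node $k$ emits a halting probability at every propagation step $t$ and stops once the cumulative probability crosses $1 - \epsilon$ (the hard-coded 0.99 in the code corresponds to $\epsilon = 0.01$):

$$h_k^{(t)} = \sigma\big(w^\top z_k^{(t)} + b\big), \qquad T_k = \min\Big\{t \le T_{\max} : \sum_{t'=1}^{t} h_k^{(t')} \ge 1 - \epsilon\Big\},$$

where $T_{\max}$ is set by `--niter`. The step count $T_k$ and the remainder $R_k = 1 - \sum_{t' < T_k} h_k^{(t')}$ are returned by the propagation layer and penalized in the loss with weight `--prop-penalty`.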
161 changes: 161 additions & 0 deletions cogdl/models/nn/ap_gcn.py
@@ -0,0 +1,161 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .. import BaseModel, register_model
from torch_geometric.utils.dropout import dropout_adj
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.nn.conv import MessagePassing
from torch.nn import ModuleList, Dropout, ReLU, Linear


class AdaptivePropagation(MessagePassing):
    """Graph propagation with a learned per-node halting unit (ACT-style)."""

    def __init__(self, niter, h_size, **kwargs):
        super(AdaptivePropagation, self).__init__(aggr='add', **kwargs)

self.niter = niter
self.halt = Linear(h_size, 1)
self.reg_params = list(self.halt.parameters())
self.dropout = Dropout()
self.reset_parameters()

    def reset_parameters(self):
        self.halt.reset_parameters()
        # Initialize the halting bias to logit(1/(niter + 1)) so that the
        # initial per-step halting probability is 1 / (niter + 1).
        x = self.niter + 1
        b = math.log((1 / x) / (1 - (1 / x)))
        self.halt.bias.data.fill_(b)

    def forward(self, local_preds: torch.FloatTensor, edge_index):
        sz = local_preds.size(0)
        # Per-node state: steps taken, cumulative halting probability,
        # and a mask of nodes that are still propagating.
        steps = torch.ones(sz, device=local_preds.device)
        sum_h = torch.zeros(sz, device=local_preds.device)
        continue_mask = torch.ones(sz, dtype=torch.bool, device=local_preds.device)
        x = torch.zeros_like(local_preds)

prop = self.dropout(local_preds)
        for _ in range(self.niter):
            old_prop = prop
            continue_fmask = continue_mask.float()

            # Apply DropEdge, then recompute the symmetric GCN normalization
            # D^{-1/2}(A + I)D^{-1/2} for the sampled edge set.
            drop_edge_index, _ = dropout_adj(edge_index, training=self.training)
            drop_edge_index, _ = add_self_loops(drop_edge_index, num_nodes=sz)
            row, col = drop_edge_index
            deg = degree(col, sz, dtype=prop.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            drop_norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

prop = self.propagate(drop_edge_index, x=prop, norm=drop_norm)

            # Per-node halting probability for this step.
            h = torch.sigmoid(self.halt(prop)).t().squeeze()
            # A node keeps propagating while its cumulative halting
            # probability stays below 1 - eps, with eps = 0.01.
            prob_mask = (((sum_h + h) < 0.99) & continue_mask).squeeze()
            prob_fmask = prob_mask.float()

            steps = steps + prob_fmask
            sum_h = sum_h + prob_fmask * h

            final_iter = steps < self.niter

            condition = prob_mask & final_iter
            # Nodes that continue weight the new state by sum_h; nodes at
            # their final step use the remainder 1 - sum_h instead.
            p = torch.where(condition, sum_h, 1 - sum_h)

            to_update = self.dropout(continue_fmask)[:, None]
            x = x + (prop * p[:, None] +
                     old_prop * (1 - p)[:, None]) * to_update

continue_mask = continue_mask & prob_mask

if (~continue_mask).all():
break

        # Average the accumulated representations over the steps each node took.
        x = x / steps[:, None]

        return x, steps, (1 - sum_h)

def message(self, x_j, norm):
return norm.view(-1, 1) * x_j


@register_model("ap_gcn")
class AP_GCN(BaseModel):
"""
Model Name: Adaptive Propagation Graph Convolutional Network (AP-GCN)
Paper link: https://arxiv.org/abs/2002.10306
"""

@staticmethod
def add_args(parser):
parser.add_argument("--hidden-size", type=int, default=64)
parser.add_argument("--batch-size", type=int, default=20)
parser.add_argument("--train-ratio", type=float, default=0.7)
parser.add_argument("--test-ratio", type=float, default=0.1)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--niter", type=int, default=10)
parser.add_argument("--prop_penalty", type=float, default=0.005)
parser.add_argument("--lr", type=float, default=0.001)

@classmethod
def build_model_from_args(cls, args):
return cls(
args.num_features,
args.hidden_size,
args.num_classes,
args.dropout,
args.niter,
args.prop_penalty,
args.weight_decay,
)

def __init__(self, in_feats, hidden_dim, out_feats, dropout, niter, prop_penalty, weight_decay):
super(AP_GCN, self).__init__()

        num_features = [in_feats, hidden_dim, out_feats]

layers = []
for in_features, out_features in zip(num_features[:-1], num_features[1:]):
layers.append(nn.Linear(in_features, out_features))

self.prop = AdaptivePropagation(niter, out_feats)
self.prop_penalty = prop_penalty
self.weight_decay = weight_decay
self.layers = ModuleList(layers)
self.reg_params = list(layers[0].parameters())
        self.non_reg_params = [p for layer in layers[1:] for p in layer.parameters()]

self.dropout = Dropout(p=dropout)
self.act_fn = ReLU()

self.reset_parameters()

def reset_parameters(self):
self.prop.reset_parameters()
for layer in self.layers:
layer.reset_parameters()

def forward(self, x, adj):
for i, layer in enumerate(self.layers):
x = layer(self.dropout(x))

if i == len(self.layers) - 1:
break

x = self.act_fn(x)

        x, steps, remainders = self.prop(x, adj)
        return x, steps, remainders

    def node_classification_loss(self, data):
        x, steps, remainders = self.forward(data.x, data.edge_index)
        x = F.log_softmax(x, dim=-1)
        loss = F.nll_loss(x[data.train_mask], data.y[data.train_mask])
        # L2 penalty on the first layer plus the ACT-style propagation
        # penalty on step counts and halting remainders.
        l2_reg = sum(torch.sum(param ** 2) for param in self.reg_params)
        loss += self.weight_decay / 2 * l2_reg + self.prop_penalty * (
            steps[data.train_mask] + remainders[data.train_mask]).mean()
        return loss

def predict(self, data):
x, _, _ = self.forward(data.x, data.edge_index)
return x
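
As a quick sanity check on the halting-bias initialization above (a standalone sketch, assuming only torch is installed): reset_parameters sets the bias to the logit of 1/(niter + 1), so each node's initial per-step halting probability is about 0.09 for the default niter = 10.

import math

import torch

niter = 10
x = niter + 1
b = math.log((1 / x) / (1 - (1 / x)))       # logit of 1/(niter + 1), as in reset_parameters
p0 = torch.sigmoid(torch.tensor(b)).item()  # initial per-step halting probability
assert abs(p0 - 1 / x) < 1e-6               # ~0.0909 for niter = 10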
61 changes: 61 additions & 0 deletions examples/gnn_models/ap_gcn.py
@@ -0,0 +1,61 @@
import torch

from cogdl.datasets import build_dataset
from cogdl.tasks import build_task
from cogdl.utils import build_args_from_dict, print_result, set_random_seed

DATASET_REGISTRY = {}


def build_default_args_for_node_classification(dataset):
cpu = not torch.cuda.is_available()
args = {
"lr": 0.001,
"cpu": cpu,
"device_id": [0],
"weight_decay": 0.001,
"max_epoch": 2000,
"patience": 20,
"seed": [0],
"dropout": 0.5,
"hidden_size": 64,
"niter": 10,
"prop_penalty": 0.005,
"missing_rate": -1,
"task": "node_classification",
"model": "ap_gcn",
"dataset": dataset,
}
return build_args_from_dict(args)


def register_func(name):
def register_func_name(func):
DATASET_REGISTRY[name] = func
return func

return register_func_name


@register_func('citeseer')
def citeseer_config(args):
return args


def run(dataset_name):
args = build_default_args_for_node_classification(dataset_name)
args = DATASET_REGISTRY[dataset_name](args)
dataset = build_dataset(args)
results = []
for seed in args.seed:
set_random_seed(seed)
task = build_task(args, dataset=dataset)
result = task.train()
results.append(result)
return results


if __name__ == "__main__":
datasets = ['citeseer']
results = []
for x in datasets:
results += run(x)
print_result(results, datasets, "ap_gcn")
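
The registry above makes it straightforward to extend the run to more datasets; a minimal sketch following the same pattern (the cora entry is illustrative and its hyperparameters are not tuned for AP-GCN):

@register_func('cora')
def cora_config(args):
    # Illustrative assumption: reuse the citeseer defaults unchanged.
    return args

# then extend the list under __main__:
# datasets = ['citeseer', 'cora']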
1 change: 1 addition & 0 deletions match.yml
@@ -1,5 +1,6 @@
node_classification:
  - model:
      - ap_gcn
      - gdc_gcn
      - gcn
      - gat
14 changes: 14 additions & 0 deletions tests/tasks/test_node_classification.py
@@ -646,6 +646,19 @@ def test_dropedge_inceptiongcn_cora():
ret = task.train()
assert 0 <= ret["Acc"] <= 1

def test_ap_gcn_citeseer():
args = get_default_args()
args.weight_decay = 0.001
args.niter = 10
args.prop_penalty = 0.005
args.lr = 0.001
args.task = "node_classification"
args.dataset = "citeseer"
args.model = "ap_gcn"

task = build_task(args)
ret = task.train()
assert 0 <= ret["Acc"] <= 1

def test_pprgo_cora():
args = get_default_args()
@@ -707,4 +720,5 @@ def test_pprgo_cora():
test_dropedge_inceptiongcn_cora()
test_dropedge_densegcn_cora()
test_unet_cora()
test_ap_gcn_citeseer()
test_pprgo_cora()