Skip to content

Commit

Permalink
Merge pull request #8 from chaitjo/dev
Browse files Browse the repository at this point in the history
ICML camera-ready updates
  • Loading branch information
chaitjo authored Jun 18, 2023
2 parents 490c5b0 + ffc7564 commit ede7198
Show file tree
Hide file tree
Showing 30 changed files with 1,828 additions and 1,509 deletions.
40 changes: 23 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
*Geometric GNN Dojo* is a pedagogical resource for beginners and experts to explore the design space of **Graph Neural Networks for geometric graphs**.

Check out the accompanying paper ['On the Expressive Power of Geometric Graph Neural Networks'](https://arxiv.org/abs/2301.09308), which studies the expressivity and theoretical limits of geometric GNNs.
> Chaitanya K. Joshi*, Cristian Bodnar*, Simon V. Mathis, Taco Cohen, and Pietro Liò. On the Expressive Power of Geometric Graph Neural Networks. *NeurIPS 2022 Workshop on Symmetry and Geometry in Neural Representations.*
> Chaitanya K. Joshi*, Cristian Bodnar*, Simon V. Mathis, Taco Cohen, and Pietro Liò. On the Expressive Power of Geometric Graph Neural Networks. *International Conference on Machine Learning*.
>
>[PDF](https://arxiv.org/pdf/2301.09308.pdf) | [Slides](https://www.chaitjo.com/publication/joshi-2023-expressive/Geometric_GNNs_Slides.pdf) | [Video](https://youtu.be/5ulJMtpiKGc)
**New to geometric GNNs:** try our practical notebook on [*Geometric GNNs 101*](geometric_gnn_101.ipynb), prepared for MPhil students at the University of Cambridge.

<a target="_blank" href="https://colab.research.google.com/github/chaitjo/geometric-gnn-dojo/blob/main/geometric_gnn_101.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab (recommended!)"/>
</a>

## Architectures

The `/src` directory provides unified implementations of several popular geometric GNN architectures:
- Invariant GNNs: [SchNet](https://arxiv.org/abs/1706.08566), [DimeNet](https://arxiv.org/abs/2003.03123)
The `/models` directory provides unified implementations of several popular geometric GNN architectures:
- Invariant GNNs: [SchNet](https://arxiv.org/abs/1706.08566), [DimeNet](https://arxiv.org/abs/2003.03123), [SphereNet](https://arxiv.org/abs/2102.05013)
- Equivariant GNNs using cartesian vectors: [E(n) Equivariant GNN](https://proceedings.mlr.press/v139/satorras21a.html), [GVP-GNN](https://arxiv.org/abs/2009.01411)
- Equivariant GNNs using spherical tensors: [Tensor Field Network](https://arxiv.org/abs/1802.08219), [MACE](http://arxiv.org/abs/2206.07697)
- 🔥 Your new geometric GNN architecture?
Expand Down Expand Up @@ -76,17 +76,23 @@ pip install torch-geometric
├── geometric_gnn_101.ipynb # A gentle introduction to Geometric GNNs
|
├── experiments # Synthetic experiments
├── incompleteness.ipynb # Experiment on counterexamples from Pozdnyakov et al.
| |
│ ├── kchains.ipynb # Experiment on k-chains
│ └── rotsym.ipynb # Experiment on rotationally symmetric structures
│ ├── rotsym.ipynb # Experiment on rotationally symmetric structures
│ ├── incompleteness.ipynb # Experiment on counterexamples from Pozdnyakov et al.
| └── utils # Helper functions for training, plotting, etc.
|
└── src # Geometric GNN models library
├── models.py # Models built using layers
├── gvp_layers.py # Layers for GVP-GNN
├── egnn_layers.py # Layers for E(n) Equivariant GNN
├── tfn_layers.py # Layers for Tensor Field Networks
├── modules # Layers for MACE
└── utils # Helper functions for training, plotting, etc.
└── models # Geometric GNN models library
|
├── schnet.py # SchNet model
├── dimenet.py # DimeNet model
├── spherenet.py # SphereNet model
├── egnn.py # E(n) Equivariant GNN model
├── gvpgnn.py # GVP-GNN model
├── tfn.py # Tensor Field Network model
├── mace.py # MACE model
├── layers # Layers for each model
└── modules # Modules and layers for MACE
```


Expand All @@ -99,10 +105,10 @@ We welcome your questions and feedback via email or GitHub Issues.
## Citation

```
@article{joshi2022expressive,
@inproceedings{joshi2023expressive,
title={On the Expressive Power of Geometric Graph Neural Networks},
author={Joshi, Chaitanya K. and Bodnar, Cristian and Mathis, Simon V. and Cohen, Taco and Liò, Pietro},
journal={NeurIPS Workshop on Symmetry and Geometry in Neural Representations},
year={2022},
booktitle={International Conference on Machine Learning},
year={2023},
}
```
```
Binary file modified experiments/fig/axes-of-expressivity.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiments/fig/incompleteness.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
406 changes: 78 additions & 328 deletions experiments/incompleteness.ipynb

Large diffs are not rendered by default.

216 changes: 24 additions & 192 deletions experiments/kchains.ipynb

Large diffs are not rendered by default.

142 changes: 40 additions & 102 deletions experiments/rotsym.ipynb

Large diffs are not rendered by default.

File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
import random
from tqdm import tqdm
from tqdm.autonotebook import tqdm # from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score

Expand Down
7 changes: 7 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from models.schnet import SchNetModel
from models.dimenet import DimeNetPPModel
from models.spherenet import SphereNetModel
from models.egnn import EGNNModel
from models.gvpgnn import GVPGNNModel
from models.tfn import TFNModel
from models.mace import MACEModel
105 changes: 105 additions & 0 deletions models/dimenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Callable, Union

import torch
from torch.nn import functional as F
from torch_geometric.nn import DimeNetPlusPlus
from torch_scatter import scatter


class DimeNetPPModel(DimeNetPlusPlus):
"""
DimeNet model from "Directional message passing for molecular graphs".
This class extends the DimeNetPlusPlus base class for PyG.
"""
def __init__(
self,
hidden_channels: int = 128,
in_dim: int = 1,
out_dim: int = 1,
num_layers: int = 4,
int_emb_size: int = 64,
basis_emb_size: int = 8,
out_emb_channels: int = 256,
num_spherical: int = 7,
num_radial: int = 6,
cutoff: float = 10,
max_num_neighbors: int = 32,
envelope_exponent: int = 5,
num_before_skip: int = 1,
num_after_skip: int = 2,
num_output_layers: int = 3,
act: Union[str, Callable] = 'swish'
):
"""
Initializes an instance of the DimeNetPPModel class with the provided parameters.
Parameters:
- hidden_channels (int): Number of channels in the hidden layers (default: 128)
- in_dim (int): Input dimension of the model (default: 1)
- out_dim (int): Output dimension of the model (default: 1)
- num_layers (int): Number of layers in the model (default: 4)
- int_emb_size (int): Embedding size for interaction features (default: 64)
- basis_emb_size (int): Embedding size for basis functions (default: 8)
- out_emb_channels (int): Number of channels in the output embeddings (default: 256)
- num_spherical (int): Number of spherical harmonics (default: 7)
- num_radial (int): Number of radial basis functions (default: 6)
- cutoff (float): Cutoff distance for interactions (default: 10)
- max_num_neighbors (int): Maximum number of neighboring atoms to consider (default: 32)
- envelope_exponent (int): Exponent of the envelope function (default: 5)
- num_before_skip (int): Number of layers before the skip connections (default: 1)
- num_after_skip (int): Number of layers after the skip connections (default: 2)
- num_output_layers (int): Number of output layers (default: 3)
- act (Union[str, Callable]): Activation function (default: 'swish' or callable)
Note:
- The `act` parameter can be either a string representing a built-in activation function,
or a callable object that serves as a custom activation function.
"""
super().__init__(
hidden_channels,
out_dim,
num_layers,
int_emb_size,
basis_emb_size,
out_emb_channels,
num_spherical,
num_radial,
cutoff,
max_num_neighbors,
envelope_exponent,
num_before_skip,
num_after_skip,
num_output_layers,
act
)

def forward(self, batch):

i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
batch.edge_index, num_nodes=batch.atoms.size(0))

# Calculate distances.
dist = (batch.pos[i] - batch.pos[j]).pow(2).sum(dim=-1).sqrt()

# Calculate angles.
pos_i = batch.pos[idx_i]
pos_ji, pos_ki = batch.pos[idx_j] - pos_i, batch.pos[idx_k] - pos_i
a = (pos_ji * pos_ki).sum(dim=-1)
b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
angle = torch.atan2(b, a)

rbf = self.rbf(dist)
sbf = self.sbf(dist, angle, idx_kj)

# Embedding block.
x = self.emb(batch.atoms, rbf, i, j)
P = self.output_blocks[0](x, rbf, i, num_nodes=batch.pos.size(0))

# Interaction blocks.
for interaction_block, output_block in zip(self.interaction_blocks,
self.output_blocks[1:]):
x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
P += output_block(x, rbf, i)

return P.sum(dim=0) if batch is None else scatter(P, batch.batch, dim=0)
87 changes: 87 additions & 0 deletions models/egnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool

from models.layers.egnn_layer import EGNNLayer


class EGNNModel(torch.nn.Module):
"""
E-GNN model from "E(n) Equivariant Graph Neural Networks".
"""
def __init__(
self,
num_layers: int = 5,
emb_dim: int = 128,
in_dim: int = 1,
out_dim: int = 1,
activation: str = "relu",
norm: str = "layer",
aggr: str = "sum",
pool: str = "sum",
residual: bool = True,
equivariant_pred: bool = False
):
"""
Initializes an instance of the EGNNModel class with the provided parameters.
Parameters:
- num_layers (int): Number of layers in the model (default: 5)
- emb_dim (int): Dimension of the node embeddings (default: 128)
- in_dim (int): Input dimension of the model (default: 1)
- out_dim (int): Output dimension of the model (default: 1)
- activation (str): Activation function to be used (default: "relu")
- norm (str): Normalization method to be used (default: "layer")
- aggr (str): Aggregation method to be used (default: "sum")
- pool (str): Global pooling method to be used (default: "sum")
- residual (bool): Whether to use residual connections (default: True)
- equivariant_pred (bool): Whether it is an equivariant prediction task (default: False)
"""
super().__init__()
self.equivariant_pred = equivariant_pred
self.residual = residual

# Embedding lookup for initial node features
self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

# Stack of GNN layers
self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr))

# Global pooling/readout function
self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

if self.equivariant_pred:
# Linear predictor for equivariant tasks using geometric features
self.pred = torch.nn.Linear(emb_dim + 3, out_dim)
else:
# MLP predictor for invariant tasks using only scalar features
self.pred = torch.nn.Sequential(
torch.nn.Linear(emb_dim, emb_dim),
torch.nn.ReLU(),
torch.nn.Linear(emb_dim, out_dim)
)

def forward(self, batch):

h = self.emb_in(batch.atoms) # (n,) -> (n, d)
pos = batch.pos # (n, 3)

for conv in self.convs:
# Message passing layer
h_update, pos_update = conv(h, pos, batch.edge_index)

# Update node features (n, d) -> (n, d)
h = h + h_update if self.residual else h_update

# Update node coordinates (no residual) (n, 3) -> (n, 3)
pos = pos_update

if not self.equivariant_pred:
# Select only scalars for invariant prediction
out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d)
else:
out = self.pool(torch.cat([h, pos], dim=-1), batch.batch)

return self.pred(out) # (batch_size, out_dim)
Loading

0 comments on commit ede7198

Please sign in to comment.