Skip to content

Commit

Permalink
added correct kwargs to lifting layer
Browse files Browse the repository at this point in the history
  • Loading branch information
ThijsKuipers1995 committed Sep 12, 2023
1 parent a3998cf commit 386f30b
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 4 deletions.
2 changes: 1 addition & 1 deletion gconv/gnn/kernels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def forward(self, H) -> Tensor:
weight = self.sample_Rn(
self.weight.repeat_interleave(H.shape[0], dim=0),
H_product.repeat(self.out_channels, *product_dims),
**self.sample_H_kwargs,
**self.sample_Rn_kwargs,
).view(
self.out_channels, num_H, self.in_channels // self.groups, *self.kernel_size
)
Expand Down
2 changes: 2 additions & 0 deletions gconv/gnn/modules/gconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"""
from __future__ import annotations

from matplotlib import pyplot as plt

from gconv.gnn.kernels import (
GroupKernel,
GLiftingKernel,
Expand Down
96 changes: 96 additions & 0 deletions gconv/tests/test_equivariance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import sys

sys.path.append("..")


from gconv.gnn import GLiftingConvSE3
from gconv.geometry.groups import so3 as R
from gconv.gnn import functional as gF
import torch

from matplotlib import pyplot as plt

from torch.nn.functional import grid_sample


def plot_activations(activations):
B, _, H, *_ = activations.shape
fig = plt.figure()
for i in range(B):
for j in range(H):
ax = fig.add_subplot(B, H, 1 + j + i * H)
ax.imshow(activations[i, 1, j, 2].detach().numpy())
ax.axis(False)
plt.show()


def test_se3_lifting_conv():
torch.manual_seed(0)

batch_size = 1
in_channels = 2
out_channels = 3
kernel_size = 5
group_kernel_size = 4
groups = 1
bias = False

input = torch.zeros(batch_size, in_channels, 5, 5, 5)
input[:, :, 2, 2, :] = 1

grid_H = torch.Tensor(
[
[
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
[
[-1, 0, 0],
[0, -1, 0],
[0, 0, 1],
],
]
)
from math import pi

grid_H = R.matrix_z(torch.linspace(0, 2 * pi, 5)[:-1])

grid_R3 = gF.create_grid_R3(5)

grid_R3_rotated = R.left_apply_to_R3(grid_H, grid_R3)
input_rotated = grid_sample(
input.repeat(grid_H.shape[0], 1, 1, 1, 1),
grid_R3_rotated,
mode="nearest",
padding_mode="zeros",
)

model = GLiftingConvSE3(
in_channels,
out_channels,
kernel_size,
group_kernel_size=group_kernel_size,
padding="same",
groups=groups,
bias=bias,
sampling_mode="nearest",
sampling_padding_mode="zeros",
mask=True,
permute_output_grid=False,
)

output, H = model(input_rotated, grid_H)

# plot_activations(input[:, :, None])
# plot_activations(input_rotated[:, :, None])
print(output.shape)
plot_activations(output)


def main():
test_se3_lifting_conv()


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion gconv/tests/test_gconv_e2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from gconv.nn import GLiftingConvE2, GSeparableConvE2, GConvE2
from gconv.gnn import GLiftingConvE2, GSeparableConvE2, GConvE2
from gconv.geometry.groups import o2


Expand Down
2 changes: 1 addition & 1 deletion gconv/tests/test_gconv_e3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from gconv.nn import GLiftingConvE3, GSeparableConvE3, GConvE3
from gconv.gnn import GLiftingConvE3, GSeparableConvE3, GConvE3
from gconv.geometry.groups import o3


Expand Down
2 changes: 1 addition & 1 deletion gconv/tests/test_kernel_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

sys.path.append("..")

from gconv.nn import kernels
from gconv.gnn import kernels
from gconv.geometry.groups import so3


Expand Down

0 comments on commit 386f30b

Please sign in to comment.