From 386f30b1bec11f66f42d9ed2c2dcd80adc87923c Mon Sep 17 00:00:00 2001 From: Thijs Kuipers Date: Tue, 12 Sep 2023 17:57:06 +0200 Subject: [PATCH] added correct kwargs to lifting layer --- gconv/gnn/kernels/kernel.py | 2 +- gconv/gnn/modules/gconv.py | 2 + gconv/tests/test_equivariance.py | 96 ++++++++++++++++++++++++++++++++ gconv/tests/test_gconv_e2.py | 2 +- gconv/tests/test_gconv_e3.py | 2 +- gconv/tests/test_kernel_se3.py | 2 +- 6 files changed, 102 insertions(+), 4 deletions(-) create mode 100644 gconv/tests/test_equivariance.py diff --git a/gconv/gnn/kernels/kernel.py b/gconv/gnn/kernels/kernel.py index 898e5cb..93cabc6 100644 --- a/gconv/gnn/kernels/kernel.py +++ b/gconv/gnn/kernels/kernel.py @@ -181,7 +181,7 @@ def forward(self, H) -> Tensor: weight = self.sample_Rn( self.weight.repeat_interleave(H.shape[0], dim=0), H_product.repeat(self.out_channels, *product_dims), - **self.sample_H_kwargs, + **self.sample_Rn_kwargs, ).view( self.out_channels, num_H, self.in_channels // self.groups, *self.kernel_size ) diff --git a/gconv/gnn/modules/gconv.py b/gconv/gnn/modules/gconv.py index ebe6dc5..2bdf7ff 100644 --- a/gconv/gnn/modules/gconv.py +++ b/gconv/gnn/modules/gconv.py @@ -5,6 +5,8 @@ """ from __future__ import annotations +from matplotlib import pyplot as plt + from gconv.gnn.kernels import ( GroupKernel, GLiftingKernel, diff --git a/gconv/tests/test_equivariance.py b/gconv/tests/test_equivariance.py new file mode 100644 index 0000000..a95320d --- /dev/null +++ b/gconv/tests/test_equivariance.py @@ -0,0 +1,96 @@ +import sys + +sys.path.append("..") + + +from gconv.gnn import GLiftingConvSE3 +from gconv.geometry.groups import so3 as R +from gconv.gnn import functional as gF +import torch + +from matplotlib import pyplot as plt + +from torch.nn.functional import grid_sample + + +def plot_activations(activations): + B, _, H, *_ = activations.shape + fig = plt.figure() + for i in range(B): + for j in range(H): + ax = fig.add_subplot(B, H, 1 + j + i * H) + ax.imshow(activations[i, 1, j, 2].detach().numpy()) + ax.axis(False) + plt.show() + + +def test_se3_lifting_conv(): + torch.manual_seed(0) + + batch_size = 1 + in_channels = 2 + out_channels = 3 + kernel_size = 5 + group_kernel_size = 4 + groups = 1 + bias = False + + input = torch.zeros(batch_size, in_channels, 5, 5, 5) + input[:, :, 2, 2, :] = 1 + + grid_H = torch.Tensor( + [ + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + [ + [-1, 0, 0], + [0, -1, 0], + [0, 0, 1], + ], + ] + ) + from math import pi + + grid_H = R.matrix_z(torch.linspace(0, 2 * pi, 5)[:-1]) + + grid_R3 = gF.create_grid_R3(5) + + grid_R3_rotated = R.left_apply_to_R3(grid_H, grid_R3) + input_rotated = grid_sample( + input.repeat(grid_H.shape[0], 1, 1, 1, 1), + grid_R3_rotated, + mode="nearest", + padding_mode="zeros", + ) + + model = GLiftingConvSE3( + in_channels, + out_channels, + kernel_size, + group_kernel_size=group_kernel_size, + padding="same", + groups=groups, + bias=bias, + sampling_mode="nearest", + sampling_padding_mode="zeros", + mask=True, + permute_output_grid=False, + ) + + output, H = model(input_rotated, grid_H) + + # plot_activations(input[:, :, None]) + # plot_activations(input_rotated[:, :, None]) + print(output.shape) + plot_activations(output) + + +def main(): + test_se3_lifting_conv() + + +if __name__ == "__main__": + main() diff --git a/gconv/tests/test_gconv_e2.py b/gconv/tests/test_gconv_e2.py index b4632ee..8667bd3 100644 --- a/gconv/tests/test_gconv_e2.py +++ b/gconv/tests/test_gconv_e2.py @@ -4,7 +4,7 @@ import torch -from gconv.nn import GLiftingConvE2, GSeparableConvE2, GConvE2 +from gconv.gnn import GLiftingConvE2, GSeparableConvE2, GConvE2 from gconv.geometry.groups import o2 diff --git a/gconv/tests/test_gconv_e3.py b/gconv/tests/test_gconv_e3.py index 14db911..1fe31ca 100644 --- a/gconv/tests/test_gconv_e3.py +++ b/gconv/tests/test_gconv_e3.py @@ -5,7 +5,7 @@ import torch -from gconv.nn import GLiftingConvE3, GSeparableConvE3, GConvE3 +from gconv.gnn import GLiftingConvE3, GSeparableConvE3, GConvE3 from gconv.geometry.groups import o3 diff --git a/gconv/tests/test_kernel_se3.py b/gconv/tests/test_kernel_se3.py index 66229d3..e907bf5 100644 --- a/gconv/tests/test_kernel_se3.py +++ b/gconv/tests/test_kernel_se3.py @@ -2,7 +2,7 @@ sys.path.append("..") -from gconv.nn import kernels +from gconv.gnn import kernels from gconv.geometry.groups import so3