From 00f84e8ea3ce250d57a204ba373ad69658aaf090 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Mon, 6 Nov 2023 12:51:57 +0100 Subject: [PATCH] add compose --- src/MessagePassingIPA.jl | 15 +++++++++++++++ test/runtests.jl | 27 ++++++++++++++++++++------- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index 43cc629..b8ffef7 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -79,6 +79,21 @@ inverse_transform(rigid::RigidTransformation{T}, y::AbstractArray{T,3}) where {T y .- unsqueeze(rigid.translations, dims=2), ) +""" + compose(rigid1::RigidTransformation, rigid2::RigidTransformation) + +Compose two rigid transformations. +""" +function compose( + rigid1::RigidTransformation{T}, + rigid2::RigidTransformation{T}, +) where {T} + rotations = batched_mul(rigid1.rotations, rigid2.rotations) + translations = + batched_vec(rigid1.rotations, rigid2.translations) + rigid1.translations + return RigidTransformation(rotations, translations) +end + # Invariant point attention # ------------------------- diff --git a/test/runtests.jl b/test/runtests.jl index cd391c1..4b25282 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using MessagePassingIPA +using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points using GraphNeuralNetworks: rand_graph using Rotations: RotMatrix using Test @@ -8,17 +8,30 @@ using Test n = 100 rotations = stack(rand(RotMatrix{3,Float32}) for _ in 1:n) translations = randn(Float32, 3, n) - rigid = MessagePassingIPA.RigidTransformation(rotations, translations) + rigid = RigidTransformation(rotations, translations) x = randn(Float32, 3, 12, n) - y = MessagePassingIPA.transform(rigid, x) + y = transform(rigid, x) @test size(x) == size(y) - @test x ≈ MessagePassingIPA.inverse_transform(rigid, y) + @test x ≈ inverse_transform(rigid, y) + + n = 100 + rigid1 = + RigidTransformation(stack(rand(RotMatrix{3,Float32}) + for _ in 1:n), randn(Float32, 3, n)) + rigid2 = + RigidTransformation(stack(rand(RotMatrix{3,Float32}) + for _ in 1:n), randn(Float32, 3, n)) + rigid12 = compose(rigid1, rigid2) + x = randn(Float32, 3, 12, n) + @test transform(rigid12, x) ≈ transform(rigid1, transform(rigid2, x)) + y = transform(rigid12, x) + @test x ≈ inverse_transform(rigid2, inverse_transform(rigid1, y)) end @testset "InvariantPointAttention" begin n_dims_s = 32 n_dims_z = 16 - ipa = MessagePassingIPA.InvariantPointAttention(n_dims_s, n_dims_z) + ipa = InvariantPointAttention(n_dims_s, n_dims_z) n_nodes = 100 n_edges = 500 @@ -31,7 +44,7 @@ using Test x1 = c .+ randn(Float32, 3, n_nodes) x2 = c .+ randn(Float32, 3, n_nodes) x3 = c .+ randn(Float32, 3, n_nodes) - rigid1 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...) + rigid1 = RigidTransformation(rigid_from_3points(x1, x2, x3)...) @test ipa(g, s, z, rigid1) isa Matrix{Float32} @test size(ipa(g, s, z, rigid1)) == (n_dims_s, n_nodes) @@ -40,7 +53,7 @@ using Test x1 = R * x1 .+ t x2 = R * x2 .+ t x3 = R * x3 .+ t - rigid2 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...) + rigid2 = RigidTransformation(rigid_from_3points(x1, x2, x3)...) @test ipa(g, s, z, rigid1) ≈ ipa(g, s, z, rigid2) end end