Skip to content

Commit

Permalink
Merge pull request #8 from MurrellGroup/batchedtransformations
Browse files Browse the repository at this point in the history
Make non-breaking
  • Loading branch information
AntonOresten authored Sep 18, 2024
2 parents 123d63a + 63eb11c commit e7923d7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/MessagePassingIPA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ function rigid_from_3points(x1::AbstractMatrix, x2::AbstractMatrix, x3::Abstract
return R, t
end

function rigid_from_3points(::Type{Rigid}, args...)
R, t = rigid_from_3points(args...)
function RigidTransformation(R::AbstractArray{T,3}, t::AbstractArray{T,3}) where T<:Real
Translation(t) Rotation(R)
end

RigidTransformation(R, t::AbstractMatrix) = RigidTransformation(R, reshape(t, 3, 1, :))

# Invariant point attention
# -------------------------

Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using MessagePassingIPA: InvariantPointAttention, rigid_from_3points
using MessagePassingIPA: RigidTransformation, InvariantPointAttention, rigid_from_3points
using GraphNeuralNetworks: rand_graph
using BatchedTransformations
using Test
Expand Down Expand Up @@ -38,16 +38,16 @@ using Test
x1 = c .+ randn(Float32, 3, n_nodes)
x2 = c .+ randn(Float32, 3, n_nodes)
x3 = c .+ randn(Float32, 3, n_nodes)
rigid1 = rigid_from_3points(Rigid, x1, x2, x3)
rigid1 = RigidTransformation(rigid_from_3points(x1, x2, x3)...)
@test ipa(g, s, z, rigid1) isa Matrix{Float32}
@test size(ipa(g, s, z, rigid1)) == (n_dims_s, n_nodes)

# check invariance
R, t = values(rand(Float32, Rotation, 3)), randn(Float32, 3, 1)
R, t = values(rand(Float32, Rotation, 3)), randn(Float32, 3)
x1 = R * x1 .+ t
x2 = R * x2 .+ t
x3 = R * x3 .+ t
rigid2 = rigid_from_3points(Rigid, x1, x2, x3)
rigid2 = RigidTransformation(rigid_from_3points(x1, x2, x3)...)
@test ipa(g, s, z, rigid1) ipa(g, s, z, rigid2)
end
end

0 comments on commit e7923d7

Please sign in to comment.