Skip to content

Commit

Permalink
More testing on tree eval and products (#290)
Browse files Browse the repository at this point in the history
* More Testing on tree eval WIP (#289)

* More testing on tree eval
* products 
* Merge branch '24Q2/twig/kerevaltest' into 24Q2/test/test_mtree_eval2

* cleanup test

* comment plotting

---------

Co-authored-by: Johannes Terblanche <[email protected]>
  • Loading branch information
Affie and Affie authored May 13, 2024
1 parent 5b1a9ff commit c441bc1
Show file tree
Hide file tree
Showing 2 changed files with 320 additions and 8 deletions.
14 changes: 8 additions & 6 deletions src/services/KernelEval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ function projectSymPosDef(c::AbstractMatrix)
s = size(c)
# pretty fast to make or remake isbitstype form matrix
_c = SMatrix{s...}(c)
#TODO likely not intended project here: see AMP#283
issymmetric(_c) ? _c : project(SymmetricPositiveDefinite(s[1]),_c,_c)
end

function MvNormalKernel(
μ::AbstractVector,
μ::AbstractArray,
Σ::AbstractArray,
weight::Real=1.0
)
Expand Down Expand Up @@ -58,7 +59,7 @@ function Base.show(io::IO, mvk::MvNormalKernel)
μ = mean(mvk)
Σ2 = cov(mvk)
# Σ=sqrt(Σ2)
d = length)
d = size(Σ2,1)
print(io, "MvNormalKernel(d=",d)
print(io,",μ=",round.(μ;digits=3))
print(io,",Σ^2=[",round(Σ2[1];digits=3))
Expand Down Expand Up @@ -97,10 +98,11 @@ function distanceMalahanobisSq(
basis=DefaultOrthogonalBasis()
)
δc = distanceMalahanobisCoordinates(M,K,q,basis)
p = mean(K)
ϵ = identity_element(M, q)
X = get_vector(M, ϵ, δc, basis)
return inner(M, p, X, X)
# p = mean(K)
# ϵ = identity_element(M, q)
# X = get_vector(M, ϵ, δc, basis)
# return inner(M, p, X, X)
return δc'*δc
end


Expand Down
314 changes: 312 additions & 2 deletions test/manellic/testManellicTree.jl
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Test
using ApproxManifoldProducts
using Random
using LinearAlgebra
using StaticArrays
using TensorCast
using Manifolds
Expand Down Expand Up @@ -255,7 +256,6 @@ permref = sortperm(pts, by=s->getindex(s,1))
@test norm((pts[mtree.permute] .- mean.(mtree.leaf_kernels)) .|> s->s[1]) < 1e-6



# for (i,v) in enumerate(dict[:evaltest_1_at])
# # @show AMP.evaluate(mtree, [v;]), dict[:evaltest_1_dens][i]
# @test isapprox(dict[:evaltest_1_dens][i], AMP.evaluate(mtree, [v;]))
Expand All @@ -266,6 +266,316 @@ permref = sortperm(pts, by=s->getindex(s,1))
##
end

@testset "Test evaluate MvNormalKernel" begin

M = TranslationGroup(1)
ker = AMP.MvNormalKernel([0.0], [0.5;;])
@test isapprox(
AMP.evaluate(M, ker, [0.1]),
pdf(MvNormal(mean(ker), cov(ker)), [0.1])
)


# Test wrapped cicular distribution
function pdf_wrapped_normal(μ, σ, θ; nwrap=1000)
s = 0.0
for k = -nwrap:nwrap
s += exp(-- μ + 2pi*k)^2 / (2*σ^2))
end
return 1/*sqrt(2pi)) * s
end

M = RealCircleGroup()
ker = AMP.MvNormalKernel([0.0], [0.1;;])
@test isapprox(
AMP.evaluate(M, ker, [0.1]),
pdf_wrapped_normal(mean(ker)[], sqrt(cov(ker))[], 0.1)
)

ker = AMP.MvNormalKernel([0], [2.0;;])
@test isapprox(
AMP.evaluate(M, ker, [0.]),
AMP.evaluate(M, ker, [2pi])
)
#TODO wrapped normal distributions broken
@test_broken isapprox(
pdf_wrapped_normal(mean(ker)[], sqrt(cov(ker))[], pi),
AMP.evaluate(M, ker, [pi])
)
@test_broken isapprox(
pdf_wrapped_normal(mean(ker)[], sqrt(cov(ker))[], 0),
AMP.evaluate(M, ker, [0.])
)

##
M = SpecialEuclidean(2)
ϵ = identity_element(M)
Xc = [10, 20, 0.1]
p = exp(M, ϵ, hat(M, ϵ, Xc))
kercov = diagm([0.5, 2.0, 0.1].^2)
ker = AMP.MvNormalKernel(p, kercov)
@test isapprox(
AMP.evaluate(M, ker, p),
pdf(MvNormal(Xc, cov(ker)), Xc)
)

Xc = [10, 22, -0.1]
q = exp(M, ϵ, hat(M, ϵ, Xc))

@test isapprox(
pdf(MvNormal(cov(ker)), [0,0,0]),
AMP.evaluate(M, ker, p)
)

X = log(M, ϵ, Manifolds.compose(M, inv(M, p), q))
Xc_e = vee(M, ϵ, X)
pdf_local_coords = pdf(MvNormal(cov(ker)), Xc_e)

@test isapprox(
pdf_local_coords,
AMP.evaluate(M, ker, q),
)

delta_c = AMP.distanceMalahanobisCoordinates(M, ker, q)
X = log(M, ϵ, Manifolds.compose(M, inv(M, p), q))
Xc_e = vee(M, ϵ, X)
malad_t = Xc_e'*inv(kercov)*Xc_e
# delta_t = [10, 20, 0.1] - [10, 22, -0.1]
@test isapprox(
malad_t,
delta_c'*delta_c;
atol=1e-10
)

malad2 = AMP.distanceMalahanobisSq(M,ker,q)
@test isapprox(
malad_t,
malad2;
atol=1e-10
)

rbfd = AMP.ker(M, ker, q, 0.5, AMP.distanceMalahanobisSq)
@test isapprox(
exp(-0.5*malad_t),
rbfd;
atol=1e-10
)


# NOTE 'global' distribution would have been
X = log(M, mean(ker), q)
Xc_e = vee(M, ϵ, X)
pdf_global_coords = pdf(MvNormal(cov(ker)), Xc_e)


end

@testset "Basic ManellicTree manifolds construction and evaluations" begin
##


M = TranslationGroup(1)
ϵ = identity_element(M)
dis = MvNormal([3.0], diagm([1.0].^2))
Cpts = [rand(dis) for _ in 1:128]
pts = map(c->exp(M, ϵ, hat(M, ϵ, c)), Cpts)
mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw = [0.2;;], kernel=AMP.MvNormalKernel)

##
p = exp(M, ϵ, hat(M, ϵ, [3.0]))
y_amp = AMP.evaluate(mtree, p)

y_pdf = pdf(dis, [3.0])

@test isapprox(y_amp, y_pdf; atol=0.1)

# ps = [[p] for p = -0:0.01:6]
# ys_amp = map(p->AMP.evaluate(mtree, exp(M, ϵ, hat(M, ϵ, p))), ps)
# ys_pdf = pdf(dis, ps)

# lines(first.(ps), ys_pdf)
# lines!(first.(ps), ys_amp)

# lines!(first.(ps), ys_pdf)
# lines(first.(ps), ys_amp)
##

M = SpecialOrthogonal(2)
ϵ = identity_element(M)
dis = MvNormal([0.0], diagm([0.1].^2))
Cpts = [rand(dis) for _ in 1:128]
pts = map(c->exp(M, ϵ, hat(M, ϵ, c)), Cpts)
mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw = [0.005;;], kernel=AMP.MvNormalKernel)

##
p = exp(M, ϵ, hat(M, ϵ, [0.1]))
y_amp = AMP.evaluate(mtree, p)

y_pdf = pdf(dis, [0.1])

@test isapprox(y_amp, y_pdf; atol=0.5)

ps = [[p] for p = -0.3:0.01:0.3]
ys_amp = map(p->AMP.evaluate(mtree, exp(M, ϵ, hat(M, ϵ, p))), ps)
ys_pdf = pdf(dis, ps)

# lines(first.(ps), ys_pdf)
# lines!(first.(ps), ys_amp)


M = SpecialEuclidean(2)
ϵ = identity_element(M)
dis = MvNormal([10,20,0.1], diagm([0.5,2.0,0.1].^2))
Cpts = [rand(dis) for _ in 1:128]
pts = map(c->exp(M, ϵ, hat(M, ϵ, c)), Cpts)
mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw = diagm([0.05,0.2,0.01]), kernel=AMP.MvNormalKernel)

##
p = exp(M, ϵ, hat(M, ϵ, [10, 20, 0.1]))
y_amp = AMP.evaluate(mtree, p)
y_pdf = pdf(dis, [10,20,0.1])
#FIXME
@test_broken isapprox(y_amp, y_pdf; atol=0.1)

end


## ========================================================================================
@testset "Test Product the brute force way" begin

M = SpecialEuclidean(2)
ϵ = identity_element(M)

Xc_p = [10, 20, 0.1]
p = exp(M, ϵ, hat(M, ϵ, Xc_p))
kerp = AMP.MvNormalKernel(p, diagm([0.5, 2.0, 0.1].^2))

Xc_q = [10, 22, -0.1]
q = exp(M, ϵ, hat(M, ϵ, Xc_q))
kerq = AMP.MvNormalKernel(q, diagm([1.0, 1.0, 0.1].^2))

kerpq = calcProductGaussians(M, [kerp, kerq])

# brute force way
xs = 7:0.1:13
ys = 15:0.1:27
θs = -0.3:0.01:0.3

grid_points = map(Iterators.product(xs, ys, θs)) do (x,y,θ)
exp(M, ϵ, hat(M, ϵ, SVector(x,y,θ)))
end

# use_global_coords = true
use_global_coords = false

pdf_ps = map(grid_points) do gp
if use_global_coords
X = log(M, p, gp)
Xc_e = vee(M, ϵ, X)
pdf(MvNormal(cov(kerp)), Xc_e)
else
X = log(M, ϵ, Manifolds.compose(M, inv(M, p), gp))
Xc_e = vee(M, ϵ, X)
pdf(MvNormal(cov(kerp)), Xc_e)
end
end

pdf_qs = map(grid_points) do gp
if use_global_coords
X = log(M, q, gp)
Xc_e = vee(M, ϵ, X)
pdf(MvNormal(cov(kerq)), Xc_e)
else
X = log(M, ϵ, Manifolds.compose(M, inv(M, q), gp))
Xc_e = vee(M, ϵ, X)
pdf(MvNormal(cov(kerq)), Xc_e)
end
end

pdf_pqs = pdf_ps .* pdf_qs
# pdf_pqs ./= sum(pdf_pqs) * 0.01
# pdf_pqs .*= 15.9672

amp_ps = map(grid_points) do gp
AMP.evaluate(M, kerp, gp)
end
amp_qs = map(grid_points) do gp
AMP.evaluate(M, kerq, gp)
end

amp_pqs = map(grid_points) do gp
AMP.evaluate(M, kerpq, gp)
end

amp_bf_pqs = amp_ps .* amp_qs

#FIXME
normalized_compare_test = isapprox.(normalize(amp_pqs), normalize(amp_bf_pqs); atol=0.001)
@test_broken all(normalized_compare_test)
@warn "Brute force product test overlap $(round(count(normalized_compare_test) / length(amp_pqs) * 100, digits=2))%"

#TODO should this be local or global coords?
@test_broken findmax(pdf_pqs[:,60,30])[2] == findmax(amp_pqs[:,60,30])[2]

# these are all correct
# lines(xs, pdf_ps[:,60,30])
# lines!(xs, amp_ps[:,60,30])

# lines(ys, pdf_ps[30,:,30])
# lines!(ys, amp_ps[30,:,30])

# lines(θs, pdf_ps[30,60,:])
# lines!(θs, amp_ps[30,60,:])

#these are different for "local" vs "global"
# lines(xs, normalize(pdf_pqs[:,60,30]))
# lines!(xs, normalize(amp_pqs[:,60,30]))
# lines!(xs, normalize(amp_bf_pqs[:,60,30]))

# lines(ys, normalize(pdf_pqs[30,:,30]))
# lines!(ys, normalize(amp_pqs[30,:,30]))

# lines(θs, normalize(pdf_pqs[30,60,:]))
# lines!(θs, normalize(amp_pqs[30,60,:]))

# contour(xs, ys, pdf_pqs[:,:,30]; color = :blue)
# contour!(xs, ys, amp_pqs[:,:,30]; color = :red)
# contour!(xs, ys, amp_bf_pqs[:,:,30]; color = :green)

#just some exploration
# pdf_p = pdf(Normal(10, 0.5), xs)
# pdf_q = pdf(Normal(10, 1.0), xs)
# pdf_pq = (pdf_p .* pdf_q)
# pdf_pq ./= sum(pdf_pq) * 0.01

# lines(xs, pdf_p)
# lines!(xs, pdf_q)
# lines!(xs, pdf_pq)

# pdf_p = pdf(Normal(20, 2.0), ys)
# pdf_q = pdf(Normal(22, 1.0), ys)
# pdf_pq = (pdf_p .* pdf_q)
# pdf_pq ./= sum(pdf_pq) * 0.01

# lines(ys, pdf_p)
# lines!(ys, pdf_q)
# lines!(ys, pdf_pq)

# pdf_p = pdf(Normal(0.1, 0.1), θs)
# pdf_q = pdf(Normal(-0.1, 0.1), θs)
# pdf_pq = (pdf_p .* pdf_q)
# pdf_pq ./= sum(pdf_pq) * 0.01

# lines(θs, pdf_p)
# lines!(θs, pdf_q)
# lines!(θs, pdf_pq)


end

## ========================================================================================


@testset "Manellic basic evaluation test 1D" begin
##

Expand All @@ -274,7 +584,7 @@ pts = [zeros(1) for _ in 1:100]
bw = ones(1,1)
mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=bw, kernel=AMP.MvNormalKernel)

@test isapprox( 0.4, AMP.evaluate(mtree, SA[0.0;]); atol=0.1)
@test isapprox( pdf(Normal(0,1), 0), AMP.evaluate(mtree, SA[0.0;]))

@error "expectedLogL for different number of test points not working yet."
# AMP.expectedLogL(mtree, [randn(1) for _ in 1:5])
Expand Down

0 comments on commit c441bc1

Please sign in to comment.