Skip to content

Commit

Permalink
cleanup test
Browse files Browse the repository at this point in the history
  • Loading branch information
Affie committed May 3, 2024
1 parent 9d74603 commit b296c67
Showing 1 changed file with 44 additions and 45 deletions.
89 changes: 44 additions & 45 deletions test/manellic/testManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ permref = sortperm(pts, by=s->getindex(s,1))
@test norm((pts[mtree.permute] .- mean.(mtree.leaf_kernels)) .|> s->s[1]) < 1e-6



# for (i,v) in enumerate(dict[:evaltest_1_at])
# # @show AMP.evaluate(mtree, [v;]), dict[:evaltest_1_dens][i]
# @test isapprox(dict[:evaltest_1_dens][i], AMP.evaluate(mtree, [v;]))
Expand All @@ -267,11 +266,7 @@ permref = sortperm(pts, by=s->getindex(s,1))
##
end

## ========================================================================================
## ========================================================================================
## ========================================================================================

# @testset "Test evaluate MvNormalKernel" begin
@testset "Test evaluate MvNormalKernel" begin

M = TranslationGroup(1)
ker = AMP.MvNormalKernel([0.0], [0.5;;])
Expand Down Expand Up @@ -373,9 +368,9 @@ Xc_e = vee(M, ϵ, X)
pdf_global_coords = pdf(MvNormal(cov(ker)), Xc_e)


end


# @testset "ManellicTree SE2 basic construction and evaluations" begin
@testset "Basic ManellicTree manifolds construction and evaluations" begin
##


Expand All @@ -394,7 +389,7 @@ y_pdf = pdf(dis, [3.0])

@test isapprox(y_amp, y_pdf; atol=0.1)


if false
ps = [[p] for p = -0:0.01:6]
ys_amp = map(p->AMP.evaluate(mtree, exp(M, ϵ, hat(M, ϵ, p))), ps)
ys_pdf = pdf(dis, ps)
Expand All @@ -404,10 +399,9 @@ lines!(first.(ps), ys_amp)

lines!(first.(ps), ys_pdf)
lines(first.(ps), ys_amp)

end
##


M = SpecialOrthogonal(2)
ϵ = identity_element(M)
dis = MvNormal([0.0], diagm([0.1].^2))
Expand All @@ -421,16 +415,16 @@ y_amp = AMP.evaluate(mtree, p)

y_pdf = pdf(dis, [0.1])

@test isapprox(y_amp, y_pdf; atol=0.1)
@test isapprox(y_amp, y_pdf; atol=0.5)

ps = [[p] for p = -0.3:0.01:0.3]
ys_amp = map(p->AMP.evaluate(mtree, exp(M, ϵ, hat(M, ϵ, p))), ps)
ys_pdf = pdf(dis, ps)

if false
lines(first.(ps), ys_pdf)
lines!(first.(ps), ys_amp)


end


M = SpecialEuclidean(2)
Expand All @@ -447,8 +441,11 @@ y_pdf = pdf(dis, [10,20,0.1])
#FIXME
@test_broken isapprox(y_amp, y_pdf; atol=0.1)

end


##
## ========================================================================================
@testset "Test Product the brute force way" begin

M = SpecialEuclidean(2)
ϵ = identity_element(M)
Expand All @@ -457,12 +454,10 @@ Xc_p = [10, 20, 0.1]
p = exp(M, ϵ, hat(M, ϵ, Xc_p))
kerp = AMP.MvNormalKernel(p, diagm([0.5, 2.0, 0.1].^2))


Xc_q = [10, 22, -0.1]
q = exp(M, ϵ, hat(M, ϵ, Xc_q))
kerq = AMP.MvNormalKernel(q, diagm([1.0, 1.0, 0.1].^2))


kerpq = calcProductGaussians(M, [kerp, kerq])

# brute force way
Expand All @@ -474,7 +469,9 @@ grid_points = map(Iterators.product(xs, ys, θs)) do (x,y,θ)
exp(M, ϵ, hat(M, ϵ, SVector(x,y,θ)))
end

# use_global_coords = true
use_global_coords = false

pdf_ps = map(grid_points) do gp
if use_global_coords
X = log(M, p, gp)
Expand All @@ -501,7 +498,7 @@ end

pdf_pqs = pdf_ps .* pdf_qs
# pdf_pqs ./= sum(pdf_pqs) * 0.01
pdf_pqs .*= 15.9672
# pdf_pqs .*= 15.9672

amp_ps = map(grid_points) do gp
AMP.evaluate(M, kerp, gp)
Expand All @@ -514,6 +511,18 @@ amp_pqs = map(grid_points) do gp
AMP.evaluate(M, kerpq, gp)
end

amp_bf_pqs = amp_ps .* amp_qs

#FIXME
normalized_compare_test = isapprox.(normalize(amp_pqs), normalize(amp_bf_pqs); atol=0.001)
@test_broken all(normalized_compare_test)
@warn "Brute force product test overlap $(round(count(normalized_compare_test) / length(amp_pqs) * 100, digits=2))%"

#TODO should this be local or global coords?
@test_broken findmax(pdf_pqs[:,60,30])[2] == findmax(amp_pqs[:,60,30])[2]

if false
# these are all correct
lines(xs, pdf_ps[:,60,30])
lines!(xs, amp_ps[:,60,30])

Expand All @@ -522,21 +531,27 @@ lines!(ys, amp_ps[30,:,30])

lines(θs, pdf_ps[30,60,:])
lines!(θs, amp_ps[30,60,:])
end

if false
#these are different for "local" vs "global"
lines(xs, normalize(pdf_pqs[:,60,30]))
lines!(xs, normalize(amp_pqs[:,60,30]))
lines!(xs, normalize(amp_bf_pqs[:,60,30]))

lines(xs, pdf_pqs[:,60,30])
lines!(xs, amp_pqs[:,60,30])

lines(ys, pdf_pqs[30,:,30])
lines!(ys, amp_pqs[30,:,30])

lines(θs, pdf_pqs[30,60,:])
lines!(θs, amp_pqs[30,60,:])
lines(ys, normalize(pdf_pqs[30,:,30]))
lines!(ys, normalize(amp_pqs[30,:,30]))

contour(xs, ys, pdf_pqs[:,:,30])
contour!(xs, ys, amp_pqs[:,:,30])
lines(θs, normalize(pdf_pqs[30,60,:]))
lines!(θs, normalize(amp_pqs[30,60,:]))

contour(xs, ys, pdf_pqs[:,:,30]; color = :blue)
contour!(xs, ys, amp_pqs[:,:,30]; color = :red)
contour!(xs, ys, amp_bf_pqs[:,:,30]; color = :green)
end

if false
#just some exploration
pdf_p = pdf(Normal(10, 0.5), xs)
pdf_q = pdf(Normal(10, 1.0), xs)
pdf_pq = (pdf_p .* pdf_q)
Expand Down Expand Up @@ -564,27 +579,11 @@ pdf_pq ./= sum(pdf_pq) * 0.01
lines(θs, pdf_p)
lines!(θs, pdf_q)
lines!(θs, pdf_pq)


## ========================================================================================


X = log(M, mean(ker), q)
Xc_e = vee(M, ϵ, X)
pdf(MvNormal(cov(ker)), Xc_e)
# 0.05211875018288499
# ^ "global" vs "local" v
X = log(M, ϵ, Manifolds.compose(M, inv(M, p), q))
Xc_e = vee(M, ϵ, X)
pdf(MvNormal(cov(ker)), Xc_e)
# 0.0483649046065308

end

end

## ========================================================================================
## ========================================================================================
## ========================================================================================


@testset "Manellic basic evaluation test 1D" begin
Expand Down

0 comments on commit b296c67

Please sign in to comment.