diff --git a/test/manellic/testManellicTree.jl b/test/manellic/testManellicTree.jl index 175f54d..3675ae2 100755 --- a/test/manellic/testManellicTree.jl +++ b/test/manellic/testManellicTree.jl @@ -869,12 +869,14 @@ end @show best_cov = abs.(Optim.minimizer(res)) -@test isapprox([0.6; 0.6; 0.06], best_cov; atol=0.35) +@test isapprox([0.6; 0.6], best_cov[1:2]; atol=0.35) +@test isapprox(0.06, best_cov[3]; atol=0.04) mkd = ApproxManifoldProducts.manikde!_manellic(M,pts) -@test isapprox([0.6 0 0; 0 0.6 0; 0 0 0.06], getBW(mkd)[1]; atol=0.35) +@test isapprox([0.6 0; 0 0.6], getBW(mkd)[1][1:2,1:2]; atol=0.35) +@test isapprox(0.06, getBW(mkd)[1][3,3]; atol=0.04) ## @@ -882,6 +884,42 @@ end +@testset "Multidimensional LOOCV bandwidth optimization, SpecialEuclidean(3)" begin +## + +M = SpecialEuclidean(3) +pts = [ArrayPartition(SA[randn(3)...;],SMatrix{3,3,Float64}(collect(Rot_.RotXYZ(0.1*randn(3)...)))) for _ in 1:64] + +bw = SA[1.0; 1.0; 1.0; 0.3; 0.3; 0.3] +mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=bw,kernel=AMP.MvNormalKernel) + +cost4(σ) = begin + AMP.entropy(mtree, diagm(σ.^2)) +end + +# and optimize with "update" kernel bandwith cost +@time res = Optim.optimize( + cost4, + collect(bw), + Optim.NelderMead() +); + +@test res.ls_success + +@show best_cov = abs.(Optim.minimizer(res)) + +@test isapprox([0.75; 0.75; 0.75], best_cov[1:3]; atol=0.4) +@test isapprox([0.06; 0.06; 0.06], best_cov[4:6]; atol=0.04) + + +mkd = ApproxManifoldProducts.manikde!_manellic(M,pts) + +@test isapprox([0.75 0 0; 0 0.75 0; 0 0 0.75], getBW(mkd)[1][1:3,1:3]; atol=0.4) +@test isapprox([0.06 0 0; 0 0.06 0; 0 0 0.06], getBW(mkd)[1][4:6,4:6]; atol=0.04) + +## +end + ## # # using GLMakie @@ -1064,4 +1102,4 @@ end # lines!((s->s[1]).(XX),YY, color=:red) -# \ No newline at end of file +#