diff --git a/src/services/ManifoldKernelDensity.jl b/src/services/ManifoldKernelDensity.jl index 998f7c2..49d0010 100644 --- a/src/services/ManifoldKernelDensity.jl +++ b/src/services/ManifoldKernelDensity.jl @@ -107,7 +107,7 @@ manikde!( function manikde!_manellic( M::AbstractManifold, pts::AbstractVector; - bw=ones(manifold_dimension(M),1), + bw=diagm(ones(manifold_dimension(M))), ) # @@ -119,21 +119,29 @@ function manikde!_manellic( ) # Cost function to optimize - _cost(_pts, σ) = begin - # avoid rebuilding tree at each optim iteration!!! - # mtr = buildTree_Manellic!(M, _pts; kernel_bw=reshape(σ,manifold_dimension(M),1), kernel=MvNormalKernel) - entropy(mtree,reshape(σ,manifold_dimension(M),1)) - end + # avoid rebuilding tree at each optim iteration!!! + _cost(σ::Real) = entropy(mtree,[σ^2;;]) # reshape(σ,manifold_dimension(M),1)) + _cost(σ::AbstractVector) = entropy(mtree,diagm(σ.^2)) # reshape(σ,manifold_dimension(M),1)) + _cost(σ::AbstractMatrix) = entropy(mtree,σ.^2) # reshape(σ,manifold_dimension(M),1)) # optimize for best LOOCV bandwidth # FIXME switch to RLM (or other Manopt) techinque instead # set lower and upper bounds for Golden section optimization - lcov, ucov = getBandwidthSearchBounds(mtree) - res = Optim.optimize( - (s)->_cost(pts,[s^2;]), - lcov[1], ucov[1], Optim.GoldenSection() - ) - best_cov = [Optim.minimizer(res);;] + best_cov = if 1 === manifold_dimension(M) + lcov, ucov = getBandwidthSearchBounds(mtree) + res = Optim.optimize( + (s)->_cost([s;]), + lcov[1], ucov[1], Optim.GoldenSection() + ) + [Optim.minimizer(res);;] + else + res = Optim.optimize( + _cost, + diag(bw), # FIXME Optim API issue, if using bw::matrix then steps not PDMat (NelderMead) + Optim.NelderMead() + ) + diagm(Optim.minimizer(res)) + end # reuse (heavy lift parts of) earlier tree build # return tree with correct bandwidth diff --git a/test/manellic/testManellicTree.jl b/test/manellic/testManellicTree.jl index 93d3506..9c9c1e7 100755 --- a/test/manellic/testManellicTree.jl +++ b/test/manellic/testManellicTree.jl @@ -732,6 +732,7 @@ res = Optim.optimize( (s)->cost(pts,s^2), 0.05, 3.0, Optim.GoldenSection() ) + best_cov = Optim.minimizer(res) @test isapprox(0.5, best_cov; atol=0.3) @@ -749,6 +750,7 @@ res = Optim.optimize( (s)->cost2(s^2), 0.05, 3.0, Optim.GoldenSection() ) + @show best_cov = Optim.minimizer(res) @test isapprox(bcov_, best_cov; atol=1e-3) @@ -764,7 +766,8 @@ res = Optim.optimize( (s)->cost3(s^2), 0.05, 3.0, Optim.GoldenSection() ) -@show best_cov = Optim.minimizer(res) + +best_cov = Optim.minimizer(res) @test isapprox(bcov_, best_cov; atol=1e-3) @@ -804,6 +807,41 @@ end end +@testset "Multidimensional LOOCV bandwidth optimization" begin +## + +M = TranslationGroup(2) +pts = [1*randn(2) for _ in 1:64] + +bw = [1.0; 1.0] +mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=bw,kernel=AMP.MvNormalKernel) + +cost4(σ) = begin + AMP.entropy(mtree, diagm(σ.^2)) +end + +# and optimize with "update" kernel bandwith cost +@time res = Optim.optimize( + cost4, + bw, + Optim.NelderMead() +); + +@test res.ls_success + +@show best_cov = Optim.minimizer(res) + +@test isapprox([0.5; 0.5], best_cov; atol=0.3) + + +mkd = ApproxManifoldProducts.manikde!_manellic(M,pts) + +@test isapprox([0.5 0; 0 0.5], getBW(mkd)[1]; atol=0.3) + + +## +end + ## # # using GLMakie