Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev example of maniballtree #213

Closed
wants to merge 14 commits into from
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
KernelDensityEstimate = "2472808a-b354-52ea-a80e-1658a3c6056d"
LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -18,11 +19,11 @@ Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
Expand All @@ -42,21 +43,21 @@ Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
ApproxManiProdGadflyExt = "Gadfly"

[compat]
CoordinateTransformations = "0.5, 0.6"
DocStringExtensions = "0.7, 0.8, 0.9"
CoordinateTransformations = "0.6"
DocStringExtensions = "0.8, 0.9"
Distances = "0.10.7"
KernelDensityEstimate = "0.5.10"
Manifolds = "0.8"
ManifoldsBase = "0.13, 0.14"
NLsolve = "3, 4"
Optim = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1.0"
Requires = "0.5, 1"
Rotations = "1"
StaticArrays = "0.15, 1"
TensorCast = "0.2, 0.3, 0.4"
TransformUtils = "0.2.10"
julia = "1.4"
julia = "1.6"

[extras]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Expand Down
41 changes: 41 additions & 0 deletions examples/ManiBallTree.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@


# using Revise

using Manifolds, NearestNeighbors, Distances, StaticArrays
import Manifolds: ArrayPartition

##

Base.@kwdef struct DistSE2 <: Distances.Metric
M = SpecialEuclidean(2)
end
# (d::DistSE2)(p,q) = Manifolds.ManifoldsBase.distance(d.M,p,q)
(d::DistSE2)(p::ArrayPartition,q::ArrayPartition) = Manifolds.ManifoldsBase.distance(d.M,p,q)
function (d::DistSE2)(p::ArrayPartition,q_::AbstractVector)
q = ArrayPartition(SA[q_[1],q_[2]], SMatrix{2,2}(q_[3],q_[4],q_[5],q_[6]))
Manifolds.ManifoldsBase.distance(d.M,p,q)
end
function (d::DistSE2)(p_::AbstractVector,q_::AbstractVector)
p = ArrayPartition(SA[p_[1],p_[2]], SMatrix{2,2}(p_[3],p_[4],p_[5],p_[6]))
q = ArrayPartition(SA[q_[1],q_[2]], SMatrix{2,2}(q_[3],q_[4],q_[5],q_[6]))
Manifolds.ManifoldsBase.distance(d.M,p,q)
end

# see https://github.com/SciML/RecursiveArrayTools.jl/pull/220
Base.length(::Type{<:ArrayPartition{F,T}}) where {F,N,T <: NTuple{N,StaticArray}} = T.parameters .|> length |> sum

##

pts = [ArrayPartition(SA[randn(2)...], SMatrix{2,2}([1 0; 0 1.])) for _ in 1:10]
dSE2 = DistSE2()

##

bt = NearestNeighbors.BallTree(pts, dSE2; leafsize=1)

## How to get any node in the BallTree



##
20 changes: 17 additions & 3 deletions src/ApproxManifoldProducts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import Rotations as _Rot

import ManifoldsBase
import ManifoldsBase: AbstractManifold
using RecursiveArrayTools: ArrayPartition
export ArrayPartition
# using RecursiveArrayTools: ArrayPartition

using Manifolds

Expand All @@ -26,11 +25,18 @@ using StaticArrays
using Logging
using Statistics

import Random: rand
import NearestNeighbors as NNR
import Distances as DST


import Base: *, isapprox, convert
import LinearAlgebra: rotate!
import Statistics: mean, std, cov, var
import Random: rand

import NearestNeighbors: TreeData, NNTree
import Manifolds: ArrayPartition

import KernelDensityEstimate: getPoints, getBW

const MB = ManifoldsBase
Expand All @@ -41,6 +47,8 @@ const KDE = KernelDensityEstimate
# TODO temporary for initial version of on-manifold products
KDE.setForceEvalDirect!(true)

export ArrayPartition

# the exported API
include("ExportAPI.jl")

Expand All @@ -54,6 +62,12 @@ include("Legacy.jl")
include("services/ManifoldPartials.jl")
include("Interface.jl")

# Experimental ManifoldBallTreeBalanced
include("services/TreeDataBalanced.jl")
include("services/ManifoldTreeOps.jl")
include("services/ManifoldHyperSpheres.jl")
include("services/ManifoldBallTree.jl")

# regular features
include("CommonUtils.jl")
include("services/ManifoldKernelDensity.jl")
Expand Down
2 changes: 0 additions & 2 deletions src/Deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
# end


@deprecate R(th::Real) _Rot.RotMatrix2(th).mat # = [[cos(th);-sin(th)]';[sin(th);cos(th)]'];
@deprecate R(;x::Real=0.0,y::Real=0.0,z::Real=0.0) (M=SpecialOrthogonal(3);exp(M,identity_element(M),hat(M,Identity(M),[x,y,z]))) # convert(SO3, so3([x,y,z]))

export calcCovarianceBasic
# Returns the covariance (square), not deviation
Expand Down
2 changes: 1 addition & 1 deletion src/Legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function buildHybridManifoldCallbacks(manif::Tuple)
end

# FIXME TO BE REMOVED
_MtoSymbol(::Euclidean{Tuple{1}}) = :Euclid
_MtoSymbol(::Manifolds.Euclidean{Tuple{1}}) = :Euclid
_MtoSymbol(::Circle) = :Circular
Base.convert(::Type{<:Tuple}, M::ProductManifold) = _MtoSymbol.(M.manifolds)
Base.convert(::Type{<:Tuple}, M::TranslationGroup) = tuple([:Euclid for i in 1:manifold_dimension(M)]...)
Expand Down
Loading
Loading