Skip to content

Commit

Permalink
Simplify random point/tangent generation by using default_rng (#605)
Browse files Browse the repository at this point in the history
* Simplify random point/tangent generation by using default_rng

* fix

* simplify

* move fallback to a separate file

* remove another unnecessary method

* bump version
  • Loading branch information
mateuszbaran authored May 17, 2023
1 parent c0334d4 commit 440519b
Show file tree
Hide file tree
Showing 22 changed files with 10 additions and 233 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manifolds"
uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.8.60"
version = "0.8.61"

[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand Down
2 changes: 2 additions & 0 deletions src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ include("utils.jl")

include("product_representations.jl")

include("manifold_fallbacks.jl")

# Main Meta Manifolds
include("manifolds/ConnectionManifold.jl")
include("manifolds/MetricManifold.jl")
Expand Down
4 changes: 4 additions & 0 deletions src/manifold_fallbacks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

function Random.rand!(M::AbstractManifold, pX; kwargs...)
return rand!(Random.default_rng(), M, pX; kwargs...)
end
4 changes: 0 additions & 4 deletions src/manifolds/Circle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,6 @@ function Random.rand(rng::AbstractRNG, M::Circle{ℂ}; vector_at=nothing, σ::Re
end
end

function Random.rand!(M::Circle{ℝ}, pX; vector_at=nothing, σ::Real=one(eltype(pX)))
pX .= rand(M; vector_at, σ)
return pX
end
function Random.rand!(
rng::AbstractRNG,
M::Circle{ℝ},
Expand Down
4 changes: 0 additions & 4 deletions src/manifolds/Euclidean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -571,10 +571,6 @@ project(::Euclidean{Tuple{}}, ::Number, X::Number) = X

project!(::Euclidean, Y, p, X) = copyto!(Y, X)

function Random.rand!(::Euclidean, pX; σ=one(eltype(pX)), vector_at=nothing)
pX .= randn(eltype(pX), size(pX)) .* σ
return pX
end
function Random.rand!(
rng::AbstractRNG,
::Euclidean,
Expand Down
48 changes: 2 additions & 46 deletions src/manifolds/FixedRankMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,26 +444,8 @@ and the singular values are sampled uniformly at random.
If `vector_at` is not `nothing`, generate a random tangent vector in the tangent space of
the point `vector_at` on the `FixedRankMatrices` manifold `M`.
"""
function Random.rand(
M::FixedRankMatrices{m,n,k};
vector_at=nothing,
kwargs...,
) where {m,n,k}
if vector_at === nothing
p = SVDMPoint(
Matrix{Float64}(undef, m, k),
Vector{Float64}(undef, k),
Matrix{Float64}(undef, k, n),
)
return rand!(M, p; kwargs...)
else
X = UMVTVector(
Matrix{Float64}(undef, m, k),
Matrix{Float64}(undef, k, k),
Matrix{Float64}(undef, k, n),
)
return rand!(M, X; vector_at, kwargs...)
end
function Random.rand(M::FixedRankMatrices; vector_at=nothing, kwargs...)
return rand(Random.default_rng(), M; vector_at=vector_at, kwargs...)
end
function Random.rand(
rng::AbstractRNG,
Expand All @@ -488,32 +470,6 @@ function Random.rand(
end
end

function Random.rand!(
M::FixedRankMatrices{m,n,k},
pX;
vector_at=nothing,
kwargs...,
) where {m,n,k}
if vector_at === nothing
U = rand(Stiefel(m, k); kwargs...)
S = sort(rand(k); rev=true)
V = rand(Stiefel(n, k); kwargs...)
copyto!(M, pX, SVDMPoint(U, S, V'))
else
Up = randn(m, k)
Vp = randn(n, k)
A = randn(k, k)
copyto!(
pX,
UMVTVector(
Up - vector_at.U * vector_at.U' * Up,
A,
Vp' - Vp' * vector_at.Vt' * vector_at.Vt,
),
)
end
return pX
end
function Random.rand!(
rng::AbstractRNG,
::FixedRankMatrices{m,n,k},
Expand Down
7 changes: 0 additions & 7 deletions src/manifolds/FlagOrthogonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,6 @@ function project(M::Flag{N,dp1}, ::OrthogonalPoint, X::OrthogonalTVector) where
return OrthogonalTVector(Y)
end

function Random.rand!(
M::Flag,
pX::Union{OrthogonalPoint,OrthogonalTVector};
vector_at=nothing,
)
return rand!(Random.default_rng(), M, pX; vector_at=vector_at)
end
function Random.rand!(
rng::AbstractRNG,
M::Flag{N,dp1},
Expand Down
3 changes: 0 additions & 3 deletions src/manifolds/FlagStiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,6 @@ function project!(M::Flag, q::AbstractMatrix, p::AbstractMatrix)
return project!(get_embedding(M), q, p)
end

function Random.rand!(M::Flag, pX::AbstractMatrix; vector_at=nothing)
return rand!(Random.default_rng(), M, pX; vector_at=vector_at)
end
function Random.rand!(
rng::AbstractRNG,
M::Flag{N,dp1},
Expand Down
3 changes: 0 additions & 3 deletions src/manifolds/GeneralizedGrassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,6 @@ random Matrix onto the tangent vector at `vector_at`.
"""
rand(::GeneralizedGrassmann; σ::Real=1.0)

function Random.rand!(M::GeneralizedGrassmann{n,k,ℝ}, pX; kwargs...) where {n,k}
return Random.rand!(Random.default_rng(), M, pX; kwargs...)
end
function Random.rand!(
rng::AbstractRNG,
M::GeneralizedGrassmann{n,k,ℝ},
Expand Down
3 changes: 0 additions & 3 deletions src/manifolds/GeneralizedStiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,6 @@ random Matrix onto the tangent vector at `vector_at`.
"""
rand(::GeneralizedStiefel; σ::Real=1.0)

function Random.rand!(M::GeneralizedStiefel{n,k,ℝ}, pX; kwargs...) where {n,k}
return Random.rand!(Random.default_rng(), M, pX; kwargs...)
end
function Random.rand!(
rng::AbstractRNG,
M::GeneralizedStiefel{n,k,ℝ},
Expand Down
4 changes: 0 additions & 4 deletions src/manifolds/GrassmannStiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,6 @@ Matrix onto the tangent space at `vector_at`.
"""
rand(M::Grassmann; σ::Real=1.0)

function Random.rand!(M::Grassmann, pX; kwargs...)
Random.rand!(Random.default_rng(), M, pX; kwargs...)
return pX
end
function Random.rand!(
rng::AbstractRNG,
M::Grassmann{n,k,𝔽},
Expand Down
3 changes: 0 additions & 3 deletions src/manifolds/HyperbolicHyperboloid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,6 @@ function project!(::Hyperbolic, Y::HyperboloidTVector, p::HyperboloidPoint, X)
return (Y.value .= X .+ minkowski_metric(p.value, X) .* p.value)
end

function Random.rand!(M::Hyperbolic, pX; vector_at=nothing, σ=one(eltype(pX)))
return rand!(Random.default_rng(), M, pX; vector_at=vector_at, σ=σ)
end
function Random.rand!(
rng::AbstractRNG,
M::Hyperbolic{N},
Expand Down
9 changes: 0 additions & 9 deletions src/manifolds/KendallsPreShapeSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,6 @@ function project!(::KendallsPreShapeSpace, Y, p, X)
return Y
end

function Random.rand!(M::KendallsPreShapeSpace, pX; vector_at=nothing, σ=one(eltype(pX)))
if vector_at === nothing
project!(M, pX, randn(representation_size(M)))
else
n = σ * randn(size(pX)) # Gaussian in embedding
project!(M, pX, vector_at, n) # project to TpM (keeps Gaussianness)
end
return pX
end
function Random.rand!(
rng::AbstractRNG,
M::KendallsPreShapeSpace,
Expand Down
4 changes: 0 additions & 4 deletions src/manifolds/KendallsShapeSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,6 @@ with mean zero and standard deviation `σ`.
"""
rand(::KendallsShapeSpace; σ::Real=1.0)

function Random.rand!(M::KendallsShapeSpace{n,k}, pX; vector_at=nothing) where {n,k}
rand!(get_embedding(M), pX; vector_at=vector_at)
return pX
end
function Random.rand!(
rng::AbstractRNG,
M::KendallsShapeSpace{n,k},
Expand Down
17 changes: 0 additions & 17 deletions src/manifolds/Orthogonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,6 @@ const OrthogonalMatrices{n} = GeneralUnitaryMatrices{n,ℝ,AbsoluteDeterminantOn

OrthogonalMatrices(n) = OrthogonalMatrices{n}()

function Random.rand!(M::OrthogonalMatrices, pX; vector_at=nothing, σ::Real=one(eltype(pX)))
if vector_at === nothing
# Special case: Rotations(1) is just zero-dimensional
(manifold_dimension(M) == 0) && return fill!(pX, rand((-1, 1)))
A = randn(representation_size(M))
s = diag(sign.(qr(A).R))
D = Diagonal(s)
pX .= qr(A).Q * D
else
# Special case: Rotations(1) is just zero-dimensional
(manifold_dimension(M) == 0) && return fill!(pX, 0)
A = σ .* randn(representation_size(M))
pX .= triu(A, 1) .- transpose(triu(A, 1))
normalize!(pX)
end
return pX
end
function Random.rand!(
rng::AbstractRNG,
M::OrthogonalMatrices,
Expand Down
8 changes: 0 additions & 8 deletions src/manifolds/PositiveNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,6 @@ function parallel_transport_to!(::PositiveNumbers, Y, p, X, q)
return (Y .= X .* q ./ p)
end

function Random.rand!(::PositiveNumbers, pX; σ=one(eltype(pX)), vector_at=nothing)
if vector_at === nothing
pX .= exp(randn() * σ)
else
pX .= vector_at * randn() * σ
end
return pX
end
function Random.rand!(
rng::AbstractRNG,
::PositiveNumbers,
Expand Down
24 changes: 0 additions & 24 deletions src/manifolds/ProductManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1154,30 +1154,6 @@ function Random.rand(
end
end

function Random.rand!(
M::ProductManifold,
pX;
vector_at=nothing,
parts_kwargs=map(_ -> (;), M.manifolds),
)
if vector_at === nothing
map(
(N, q, kwargs) -> rand!(N, q; kwargs...),
M.manifolds,
submanifold_components(M, pX),
parts_kwargs,
)
else
map(
(N, X, p, kwargs) -> rand!(N, X; vector_at=p, kwargs...),
M.manifolds,
submanifold_components(M, pX),
submanifold_components(M, vector_at),
parts_kwargs,
)
end
return pX
end
function Random.rand!(
rng::AbstractRNG,
M::ProductManifold,
Expand Down
21 changes: 1 addition & 20 deletions src/manifolds/Rotations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,26 +290,7 @@ function Random.rand(
convert(TResult, _fix_random_rotation(A))
end
end
function Random.rand!(M::Rotations, pX; vector_at=nothing, σ::Real=one(eltype(pX)))
if vector_at === nothing
# Special case: Rotations(1) is just zero-dimensional
(manifold_dimension(M) == 0) && return fill!(pX, 1)
A = randn(representation_size(M))
s = diag(sign.(qr(A).R))
D = Diagonal(s)
pX .= qr(A).Q * D
if det(pX) < 0
pX[:, [1, 2]] = pX[:, [2, 1]]
end
else
# Special case: Rotations(1) is just zero-dimensional
(manifold_dimension(M) == 0) && return fill!(pX, 0)
A = σ .* randn(representation_size(M))
pX .= triu(A, 1) .- transpose(triu(A, 1))
normalize!(pX)
end
return pX
end

function Random.rand!(
rng::AbstractRNG,
M::Rotations,
Expand Down
9 changes: 0 additions & 9 deletions src/manifolds/Sphere.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,6 @@ project(::AbstractSphere, ::Any, ::Any)

project!(::AbstractSphere, Y, p, X) = (Y .= X .- real(dot(p, X)) .* p)

function Random.rand!(M::AbstractSphere, pX; vector_at=nothing, σ=one(eltype(pX)))
if vector_at === nothing
project!(M, pX, randn(eltype(pX), representation_size(M)))
else
n = σ * randn(eltype(pX), size(pX)) # Gaussian in embedding
project!(M, pX, vector_at, n) #project to TpM (keeps Gaussianness)
end
return pX
end
function Random.rand!(
rng::AbstractRNG,
M::AbstractSphere,
Expand Down
3 changes: 0 additions & 3 deletions src/manifolds/Stiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,6 @@ random Matrix onto the tangent vector at `vector_at`.
"""
rand(::Stiefel; σ::Real=1.0)

function Random.rand!(M::Stiefel, pX; kwargs...)
return Random.rand!(Random.default_rng(), M, pX; kwargs...)
end
function Random.rand!(
rng::AbstractRNG,
M::Stiefel{n,k,𝔽},
Expand Down
45 changes: 0 additions & 45 deletions src/manifolds/SymmetricPositiveDefinite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,51 +403,6 @@ Generate a random symmetric positive definite matrix on the
"""
rand(M::SymmetricPositiveDefinite; σ::Real=1)

function Random.rand!(M::SymmetricPositiveDefinite, pX::SPDPoint; kwargs...)
p = rand(M; kwargs...)
pP = SPDPoint(p; store_p=false, store_sqrt=false, store_sqrt_inv=false)
!ismissing(pX.p) && pX.p .= p
copyto!(pX.eigen.values, pP.eigen.values)
copyto!(pX.eigen.vectors, pP.eigen.vectors)
!ismissing(pX.sqrt) && pX.sqrt .= spd_sqrt(pP)
!ismissing(pX.sqrt_inv) && pX.sqrt_inv .= spd_sqrt_inv(pP)
return pX
end

function Random.rand!(
M::SymmetricPositiveDefinite{N},
pX;
vector_at=nothing,
σ::Real=one(eltype(pX)) /
(vector_at === nothing ? 1 : norm(convert(AbstractMatrix, vector_at))),
tangent_distr=:Gaussian,
) where {N}
if vector_at === nothing
D = Diagonal(1 .+ rand(N)) # random diagonal matrix
s = qr* randn(N, N)) # random q
pX .= Symmetric(s.Q * D * transpose(s.Q))
elseif tangent_distr === :Gaussian
# generate ONB in TxM
vector_at_matrix = convert(AbstractMatrix, vector_at)
I = one(vector_at_matrix)
B = get_basis(M, vector_at, DiagonalizingOrthonormalBasis(I))
Ξ = get_vectors(M, vector_at, B)
Ξx =
vector_transport_to.(
Ref(M),
Ref(I),
Ξ,
Ref(vector_at_matrix),
Ref(ParallelTransport()),
)
pX .= sum* randn(length(Ξx)) .* Ξx)
elseif tangent_distr === :Rician
C = cholesky(Hermitian(vector_at))
R = C.L + sqrt(σ) * triu(randn(size(vector_at, 1), size(vector_at, 2)), 0)
pX .= R * R'
end
return pX
end
function Random.rand!(
rng::AbstractRNG,
M::SymmetricPositiveDefinite{N},
Expand Down
Loading

2 comments on commit 440519b

@mateuszbaran
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/83799

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.61 -m "<description of version>" 440519bbb4869b7312f725067fb8262058eb2925
git push origin v0.8.61

Please sign in to comment.