Skip to content

Commit

Permalink
Fix errors relating to RecursiveArrayTools, Base.stack, and GaussianD…
Browse files Browse the repository at this point in the history
…istributions.jl (#314)

* Roll my own Gaussians to fix the errors

* Fix chi2 which relied on the iterator interface

* Fix an import which relied on GaussianDistributions

* Remove GaussianDistributions

* JuliaFormatter.jl

* Remove some more usages of the iterator interface

* Fix some very unexpected errors relating to DiagonalEK1 on second order ODEs

* Remove another usage of the iterator gaussian

* JuliaFormatter.jl

* Add two more compat bounds

* Simplify the Gaussian implementation

* Remove unused cdf definition

* Small cleanup and better Gaussian printing

* Proper testing of the Gaussians

* JuliaFormatter.jl

* Make the DiagonalEK1-for-SecondOrderODEs case a bit more readable

* Find a reasonable extension of the assert statement

* Fix the `idxs` error

* Test for warning

* Fix another failing test

* Fix docs by fixing one of the `idxs`-related failing plots
  • Loading branch information
nathanaelbosch authored Jun 11, 2024
1 parent 430729f commit a991f99
Show file tree
Hide file tree
Showing 14 changed files with 229 additions and 28 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteHorizonGramians = "b59a298d-d283-4a37-9369-85a9f9a111a5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixEquations = "99c1a7ee-ab34-5fd5-8076-27c950a045f4"
Expand All @@ -22,6 +21,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PSDMatrices = "fe68d972-6fd8-4755-bdf0-97d4c54cefdc"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -56,7 +56,6 @@ FillArrays = "1.9"
FiniteHorizonGramians = "0.2"
ForwardDiff = "0.10"
FunctionWrappersWrappers = "0.1.3"
GaussianDistributions = "0.5"
Kronecker = "0.5.4"
LinearAlgebra = "1"
MatrixEquations = "2"
Expand All @@ -65,6 +64,7 @@ OrdinaryDiffEq = "6.52"
PSDMatrices = "0.4.7"
PrecompileTools = "1"
Printf = "1"
Random = "1"
RecipesBase = "1"
RecursiveArrayTools = "2, 3"
Reexport = "1"
Expand Down
2 changes: 1 addition & 1 deletion ext/DiffEqDevToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Statistics
using LinearAlgebra

function chi2(gaussian_estimate, actual_value)
μ, Σ = gaussian_estimate
μ, Σ = mean(gaussian_estimate), cov(gaussian_estimate)
d = length(μ)
diff = μ - actual_value
if iszero(Σ)
Expand Down
4 changes: 2 additions & 2 deletions ext/RecipesBaseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ end
x::Matrix{<:Gaussian}, y::Matrix{<:Gaussian},
)
@warn "This plot does not visualize any uncertainties"
xmeans = mean.(x)
ymeans = mean.(y)
xmeans = mean.(x)'
ymeans = mean.(y)'
return xmeans, ymeans
end
@recipe function f(
Expand Down
9 changes: 6 additions & 3 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ __precompile__()
module ProbNumDiffEq

import Base:
copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length, zero
copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length, zero,
eltype

using LinearAlgebra
import LinearAlgebra: mul!
import LinearAlgebra: mul!, norm_sqr
import Statistics: mean, var, std, cov
import Random: rand, GLOBAL_RNG, AbstractRNG
using Printf
using DocStringExtensions

Expand All @@ -33,7 +35,7 @@ using FillArrays
using MatrixEquations
using DiffEqCallbacks

@reexport using GaussianDistributions
# @reexport using GaussianDistributions

@reexport using PSDMatrices
import PSDMatrices: X_A_Xt, X_A_Xt!, unfactorize
Expand Down Expand Up @@ -69,6 +71,7 @@ export IsometricKroneckerCovariance, DenseCovariance, BlockDiagonalCovariance
abstract type AbstractODEFilterCache <: OrdinaryDiffEq.OrdinaryDiffEqCache end

include("gaussians.jl")
export Gaussian

include("priors/common.jl")
include("priors/iwp.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks/manifoldupdate.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function manifoldupdate!(cache, residualf; maxiters=100, ϵ₁=1e-25, ϵ₂=1e-15)
m, C = cache.x
m, C = mean(cache.x), cov(cache.x)

# Create some caches
@unpack SolProj, tmp, H, x_tmp = cache
Expand All @@ -11,7 +11,7 @@ function manifoldupdate!(cache, residualf; maxiters=100, ϵ₁=1e-25, ϵ₂=1e-1
_K1, _K2 = cache.C_DxD[:, 1:d], cache.C_2DxD[1:D, 1:d]

S = PSDMatrix(C.R[:, 1:d])
m_tmp, C_tmp = x_tmp
m_tmp, C_tmp = mean(x_tmp), cov(x_tmp)

m_i = copy(m)
local m_i_new, C_i_new
Expand Down
2 changes: 1 addition & 1 deletion src/data_likelihoods/fenrir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function fit_pnsolution_to_data!(
end

function measure_and_update!(x, u, H, R::PSDMatrix, cache)
z, S = cache.m_tmp
z, S = mean(cache.m_tmp), cov(cache.m_tmp)
_matmul!(z, H, x.μ)
z .-= u
S = PSDMatrix(make_obscov_sqrt(x.Σ.R, H, R.R))
Expand Down
11 changes: 10 additions & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@ function calc_H!(H, integ, cache)
elseif integ.alg isa DiagonalEK1
calc_H_EK0!(H, integ, cache)
OrdinaryDiffEq.calc_J!(ddu, integ, cache, true)
ddu_diag = Diagonal(ddu)
_ddu = size(ddu, 2) != d ? view(ddu, 1:d, :) : ddu
ddu_diag = if size(ddu, 2) == d
# the normal case: just extract the diagonal
Diagonal(_ddu)
else
# the "SecondOrderODEProblem" case: since f(ddu, du, u, p, t) we need a bit more
topleft = view(ddu, 1:d, 1:d)
topright = view(ddu, 1:d, d+1:2d)
BlocksOfDiagonals([[topleft[i, i] topright[i, i];] for i in 1:d])
end
_matmul!(H, ddu_diag, cache.SolProj, -1.0, 1.0)
else
error("Unknown algorithm")
Expand Down
2 changes: 1 addition & 1 deletion src/diffusions/calibration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ function local_diagonal_diffusion(cache)
@unpack d, q, H, Qh, measurement, m_tmp = cache
tmp = m_tmp.μ
@unpack local_diffusion = cache
@assert (H == cache.E1) || (H == cache.E2)
@assert (H == cache.E1) || (H == cache.E2) || H isa BlocksOfDiagonals

z = measurement.μ
# HQH = H * unfactorize(Qh) * H'
Expand Down
8 changes: 4 additions & 4 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Joseph / square-root form.
For better performance, we recommend to use the non-allocating [`update!`](@ref).
"""
function update(x::Gaussian, measurement::Gaussian, H::AbstractMatrix)
m, C = x
z, S = measurement
m, C = mean(x), cov(x)
z, S = mean(measurement), cov(measurement)

K = C * H' * inv(S)
m_new = m - K * z
Expand All @@ -31,8 +31,8 @@ function update(x::Gaussian, measurement::Gaussian, H::AbstractMatrix)
return Gaussian(m_new, C_new)
end
function update(x::SRGaussian, measurement::Gaussian, H::AbstractMatrix)
m, C = x
z, S = measurement
m, C = mean(x), cov(x)
z, S = mean(measurement), cov(measurement)

K = C * H' * inv(S)
m_new = m - K * z
Expand Down
122 changes: 112 additions & 10 deletions src/gaussians.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,132 @@
############################################################################################
# Useful things when working with GaussianDistributions.Gaussian
# `Gaussian` distributions
# Based on @mschauer's GaussianDistributions.jl
############################################################################################
"""
Gaussian(μ, Σ) -> P
Gaussian distribution with mean `μ` and covariance `Σ`. Defines `rand(P)` and `(log-)pdf(P, x)`.
Designed to work with `Number`s, `UniformScaling`s, `StaticArrays` and `PSD`-matrices.
Implementation details: On `Σ` the functions `logdet`, `whiten` and `unwhiten`
(or `chol` as fallback for the latter two) are called.
"""
struct Gaussian{T,S}
μ::T
Σ::S
Gaussian::T, Σ::S) where {T,S} = new{T,S}(μ, Σ)
end
Base.convert(::Type{Gaussian{T,S}}, g::Gaussian) where {T,S} =
Gaussian(convert(T, g.μ), convert(S, g.Σ))

# Base
Base.:(==)(g1::Gaussian, g2::Gaussian) = g1.μ == g2.μ && g1.Σ == g2.Σ
Base.isapprox(g1::Gaussian, g2::Gaussian; kwargs...) =
isapprox(g1.μ, g2.μ; kwargs...) && isapprox(g1.Σ, g2.Σ; kwargs...)
copy(P::Gaussian) = Gaussian(copy(P.μ), copy(P.Σ))
similar(P::Gaussian) = Gaussian(similar(P.μ), similar(P.Σ))
Base.copyto!(P::AbstractArray{<:Gaussian}, idx::Integer, el::Gaussian) =
(P[idx] = copy(el); P)
function Base.copy!(dst::Gaussian, src::Gaussian)
copy!(dst.μ, src.μ)
copy!(dst.Σ, src.Σ)
return dst
end
RecursiveArrayTools.recursivecopy(P::Gaussian) = copy(P)
RecursiveArrayTools.recursivecopy!(dst::Gaussian, src::Gaussian) = copy!(dst, src)
show(io::IO, g::Gaussian) = print(io, "Gaussian($(g.μ), $(g.Σ))")
show(io::IO, ::MIME"text/plain", g::Gaussian{T,S}) where {T,S} =
print(io, "Gaussian{$T,$S}($(g.μ), $(g.Σ))")
length(P::Gaussian) = length(mean(P))
size(g::Gaussian) = size(g.μ)
ndims(g::Gaussian) = ndims(g.μ)
eltype(::Type{G}) where {G<:Gaussian} = G
Base.@propagate_inbounds Base.getindex(P::Gaussian, i::Integer) =
Gaussian(P.μ[i], diag(P.Σ)[i])

# Statistics
mean(P::Gaussian) = P.μ
cov(P::Gaussian) = P.Σ
var(P::Gaussian{<:Number}) = P.Σ
std(P::Gaussian{<:Number}) = sqrt(var(P))
var(g::Gaussian) = diag(g.Σ)
std(g::Gaussian) = sqrt.(diag(g.Σ))

dim(P::Gaussian) = length(P.μ)

# whiten(Σ::PSD, z) = Σ.σ\z
whiten(Σ, z) = cholesky(Σ).U' \ z
whiten::Number, z) = sqrt(Σ) \ z
whiten::UniformScaling, z) = sqrt.λ) \ z

# unwhiten(Σ::PSD, z) = Σ.σ*z
unwhiten(Σ, z) = cholesky(Σ).U' * z
unwhiten::Number, z) = sqrt(Σ) * z
unwhiten::UniformScaling, z) = sqrt.λ) * z

sqmahal(P::Gaussian, x) = norm_sqr(whiten(P.Σ, x - P.μ))

rand(P::Gaussian) = rand(GLOBAL_RNG, P)
rand(RNG::AbstractRNG, P::Gaussian) = P.μ + unwhiten(P.Σ, randn(RNG, typeof(P.μ)))
rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}) where {T} =
P.μ + unwhiten(P.Σ, randn(RNG, T, length(P.μ)))
rand(RNG::AbstractRNG, P::Gaussian{<:Number}) =
P.μ + sqrt(P.Σ) * randn(RNG, typeof(one(P.μ)))

_logdet(Σ, d) = logdet(Σ)
_logdet(J::UniformScaling, d) = log(J.λ) * d
logpdf(P::Gaussian, x) = -(sqmahal(P, x) + _logdet(P.Σ, dim(P)) + dim(P) * log(2pi)) / 2
pdf(P::Gaussian, x) = exp(logpdf(P::Gaussian, x))

Base.:+(g::Gaussian, vec) = Gaussian(g.μ + vec, g.Σ)
Base.:+(vec, g::Gaussian) = g + vec
Base.:-(g::Gaussian, vec) = g + (-vec)
Base.:*(M, g::Gaussian) = Gaussian(M * g.μ, X_A_Xt(g.Σ, M))

function rand_scalar(RNG::AbstractRNG, P::Gaussian{T}, dims) where {T}
X = zeros(T, dims)
for i in 1:length(X)
X[i] = rand(RNG, P)
end
X
end

function rand_vector(
RNG::AbstractRNG,
P::Gaussian{Vector{T}},
dims::Union{Integer,NTuple},
) where {T}
X = zeros(T, dim(P), dims...)
for i in 1:prod(dims)
X[:, i] = rand(RNG, P)
end
X
end
rand(RNG::AbstractRNG, P::Gaussian, dim::Integer) = rand_scalar(RNG, P, dim)
rand(RNG::AbstractRNG, P::Gaussian, dims::Tuple{Vararg{Int64,N}} where {N}) =
rand_scalar(RNG, P, dims)

rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}, dim::Integer) where {T} =
rand_vector(RNG, P, dim)
rand(
RNG::AbstractRNG,
P::Gaussian{Vector{T}},
dims::Tuple{Vararg{Int64,N}} where {N},
) where {T} = rand_vector(RNG, P, dims)
rand(P::Gaussian, dims::Tuple{Vararg{Int64,N}} where {N}) = rand(GLOBAL_RNG, P, dims)
rand(P::Gaussian, dim::Integer) = rand(GLOBAL_RNG, P, dim)

# RecursiveArrayTools
RecursiveArrayTools.recursivecopy(P::Gaussian) = copy(P)
RecursiveArrayTools.recursivecopy!(dst::Gaussian, src::Gaussian) = copy!(dst, src)

# Print
show(io::IO, g::Gaussian) = print(io, "Gaussian(μ=$(g.μ), Σ=$(g.Σ))")
show(io::IO, ::MIME"text/plain", g::Gaussian{T,S}) where {T,S} = begin
println(io, "Gaussian{$T,$S}(")
println(io, " μ=$(g.μ),")
println(io, " Σ=$(g.Σ)")
print(io, ")")
end

############################################################################################
# `SRGaussian`: Gaussians with PDFMatrix covariances
############################################################################################
const SRGaussian{T,S} = Gaussian{VM,PSDMatrix{T,S}} where {VM<:AbstractVecOrMat{T}}
Base.:*(M::AbstractMatrix, g::SRGaussian) = Gaussian(M * g.μ, X_A_Xt(g.Σ, M))
# GaussianDistributions.whiten(Σ::PSDMatrix, z) = Σ.L\z

function _gaussian_mul!(g_out::SRGaussian, M::AbstractMatrix, g_in::SRGaussian)
_matmul!(g_out.μ, M, g_in.μ)
fast_X_A_Xt!(g_out.Σ, g_in.Σ, M)
Expand Down
26 changes: 26 additions & 0 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,32 @@ function (interp::ODEFilterPosterior)(
u = proj * x
return Gaussian(u.μ[idxs], diag(u.Σ)[idxs])
end
function (interp::ODEFilterPosterior)(
t::Real,
idxs::AbstractVector{<:Integer},
::Type{deriv},
p,
continuity,
) where {deriv}
q = interp.cache.q
dv = deriv.parameters[1]
proj = if deriv == Val{0}
interp.cache.SolProj
elseif dv <= q
interp.cache.Proj(dv)
else
throw(ArgumentError("We can only provide derivatives up to $q but you requested $dv"))
end
x = interpolate(
t, interp.ts, interp.x_filt, interp.x_smooth, interp.diffusions, interp.cache;
smoothed=interp.smooth)
u = proj * x
P = zeros(Bool, length(idxs), length(u))
for (i, idx) in enumerate(idxs)
P[i, idx] = 1
end
return P * u
end
function (interp::ODEFilterPosterior)(
t::AbstractVector{<:Real},
idxs,
Expand Down
2 changes: 1 addition & 1 deletion test/core/filtering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra
import ProbNumDiffEq: IsometricKroneckerProduct, BlocksOfDiagonals
import ProbNumDiffEq as PNDE
using FillArrays
import ProbNumDiffEq.GaussianDistributions: logpdf
import ProbNumDiffEq: logpdf

@testset "PREDICT" begin
# Setup
Expand Down
Loading

0 comments on commit a991f99

Please sign in to comment.