Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix errors relating to RecursiveArrayTools, Base.stack, and GaussianDistributions.jl #314

Merged
merged 21 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f9697fb
Roll my own Gaussians to fix the errors
nathanaelbosch Jun 9, 2024
2311aec
Fix chi2 which relied on the iterator interface
nathanaelbosch Jun 9, 2024
2d27487
Fix an import which relied on GaussianDistributions
nathanaelbosch Jun 9, 2024
521e198
Remove GaussianDistributions
nathanaelbosch Jun 9, 2024
7119e1c
JuliaFormatter.jl
nathanaelbosch Jun 9, 2024
bf93e03
Remove some more usages of the iterator interface
nathanaelbosch Jun 9, 2024
f4da817
Fix some very unexpected errors relating to DiagonalEK1 on second ord…
nathanaelbosch Jun 9, 2024
e092525
Remove another usage of the iterator gaussian
nathanaelbosch Jun 9, 2024
c42ac64
JuliaFormatter.jl
nathanaelbosch Jun 9, 2024
55e13d6
Add two more compat bounds
nathanaelbosch Jun 9, 2024
48bc292
Simplify the Gaussian implementation
nathanaelbosch Jun 9, 2024
9ab8d99
Remove unused cdf definition
nathanaelbosch Jun 9, 2024
527fe80
Small cleanup and better Gaussian printing
nathanaelbosch Jun 9, 2024
6fcd0b8
Proper testing of the Gaussians
nathanaelbosch Jun 9, 2024
93bba21
JuliaFormatter.jl
nathanaelbosch Jun 9, 2024
ccc80d2
Make the DiagonalEK1-for-SecondOrderODEs case a bit more readable
nathanaelbosch Jun 9, 2024
b66ad29
Find a reasonable extension of the assert statement
nathanaelbosch Jun 9, 2024
b6913f0
Fix the `idxs` error
nathanaelbosch Jun 9, 2024
7ea39ed
Test for warning
nathanaelbosch Jun 9, 2024
73df4d0
Fix another failing test
nathanaelbosch Jun 9, 2024
eb40e03
Fix docs by fixing one of the `idxs`-related failing plots
nathanaelbosch Jun 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteHorizonGramians = "b59a298d-d283-4a37-9369-85a9f9a111a5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixEquations = "99c1a7ee-ab34-5fd5-8076-27c950a045f4"
Expand All @@ -22,6 +21,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PSDMatrices = "fe68d972-6fd8-4755-bdf0-97d4c54cefdc"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -56,7 +56,6 @@ FillArrays = "1.9"
FiniteHorizonGramians = "0.2"
ForwardDiff = "0.10"
FunctionWrappersWrappers = "0.1.3"
GaussianDistributions = "0.5"
Kronecker = "0.5.4"
LinearAlgebra = "1"
MatrixEquations = "2"
Expand All @@ -65,6 +64,7 @@ OrdinaryDiffEq = "6.52"
PSDMatrices = "0.4.7"
PrecompileTools = "1"
Printf = "1"
Random = "1"
RecipesBase = "1"
RecursiveArrayTools = "2, 3"
Reexport = "1"
Expand Down
2 changes: 1 addition & 1 deletion ext/DiffEqDevToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Statistics
using LinearAlgebra

function chi2(gaussian_estimate, actual_value)
μ, Σ = gaussian_estimate
μ, Σ = mean(gaussian_estimate), cov(gaussian_estimate)
d = length(μ)
diff = μ - actual_value
if iszero(Σ)
Expand Down
4 changes: 2 additions & 2 deletions ext/RecipesBaseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ end
x::Matrix{<:Gaussian}, y::Matrix{<:Gaussian},
)
@warn "This plot does not visualize any uncertainties"
xmeans = mean.(x)
ymeans = mean.(y)
xmeans = mean.(x)'
ymeans = mean.(y)'
return xmeans, ymeans
end
@recipe function f(
Expand Down
9 changes: 6 additions & 3 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ __precompile__()
module ProbNumDiffEq

import Base:
copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length, zero
copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length, zero,
eltype

using LinearAlgebra
import LinearAlgebra: mul!
import LinearAlgebra: mul!, norm_sqr
import Statistics: mean, var, std, cov
import Random: rand, GLOBAL_RNG, AbstractRNG
using Printf
using DocStringExtensions

Expand All @@ -33,7 +35,7 @@ using FillArrays
using MatrixEquations
using DiffEqCallbacks

@reexport using GaussianDistributions
# @reexport using GaussianDistributions

@reexport using PSDMatrices
import PSDMatrices: X_A_Xt, X_A_Xt!, unfactorize
Expand Down Expand Up @@ -69,6 +71,7 @@ export IsometricKroneckerCovariance, DenseCovariance, BlockDiagonalCovariance
abstract type AbstractODEFilterCache <: OrdinaryDiffEq.OrdinaryDiffEqCache end

include("gaussians.jl")
export Gaussian

include("priors/common.jl")
include("priors/iwp.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks/manifoldupdate.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function manifoldupdate!(cache, residualf; maxiters=100, ϵ₁=1e-25, ϵ₂=1e-15)
m, C = cache.x
m, C = mean(cache.x), cov(cache.x)

# Create some caches
@unpack SolProj, tmp, H, x_tmp = cache
Expand All @@ -11,7 +11,7 @@ function manifoldupdate!(cache, residualf; maxiters=100, ϵ₁=1e-25, ϵ₂=1e-1
_K1, _K2 = cache.C_DxD[:, 1:d], cache.C_2DxD[1:D, 1:d]

S = PSDMatrix(C.R[:, 1:d])
m_tmp, C_tmp = x_tmp
m_tmp, C_tmp = mean(x_tmp), cov(x_tmp)

m_i = copy(m)
local m_i_new, C_i_new
Expand Down
2 changes: 1 addition & 1 deletion src/data_likelihoods/fenrir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function fit_pnsolution_to_data!(
end

function measure_and_update!(x, u, H, R::PSDMatrix, cache)
z, S = cache.m_tmp
z, S = mean(cache.m_tmp), cov(cache.m_tmp)
_matmul!(z, H, x.μ)
z .-= u
S = PSDMatrix(make_obscov_sqrt(x.Σ.R, H, R.R))
Expand Down
11 changes: 10 additions & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@ function calc_H!(H, integ, cache)
elseif integ.alg isa DiagonalEK1
calc_H_EK0!(H, integ, cache)
OrdinaryDiffEq.calc_J!(ddu, integ, cache, true)
ddu_diag = Diagonal(ddu)
_ddu = size(ddu, 2) != d ? view(ddu, 1:d, :) : ddu
nathanaelbosch marked this conversation as resolved.
Show resolved Hide resolved
ddu_diag = if size(ddu, 2) == d
# the normal case: just extract the diagonal
Diagonal(_ddu)
else
# the "SecondOrderODEProblem" case: since f(ddu, du, u, p, t) we need a bit more
topleft = view(ddu, 1:d, 1:d)
topright = view(ddu, 1:d, d+1:2d)
BlocksOfDiagonals([[topleft[i, i] topright[i, i];] for i in 1:d])
end
_matmul!(H, ddu_diag, cache.SolProj, -1.0, 1.0)
else
error("Unknown algorithm")
Expand Down
2 changes: 1 addition & 1 deletion src/diffusions/calibration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ function local_diagonal_diffusion(cache)
@unpack d, q, H, Qh, measurement, m_tmp = cache
tmp = m_tmp.μ
@unpack local_diffusion = cache
@assert (H == cache.E1) || (H == cache.E2)
nathanaelbosch marked this conversation as resolved.
Show resolved Hide resolved
@assert (H == cache.E1) || (H == cache.E2) || H isa BlocksOfDiagonals

z = measurement.μ
# HQH = H * unfactorize(Qh) * H'
Expand Down
8 changes: 4 additions & 4 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
For better performance, we recommend to use the non-allocating [`update!`](@ref).
"""
function update(x::Gaussian, measurement::Gaussian, H::AbstractMatrix)
m, C = x
z, S = measurement
m, C = mean(x), cov(x)
z, S = mean(measurement), cov(measurement)

K = C * H' * inv(S)
m_new = m - K * z
Expand All @@ -31,8 +31,8 @@
return Gaussian(m_new, C_new)
end
function update(x::SRGaussian, measurement::Gaussian, H::AbstractMatrix)
m, C = x
z, S = measurement
m, C = mean(x), cov(x)
z, S = mean(measurement), cov(measurement)

Check warning on line 35 in src/filtering/update.jl

View check run for this annotation

Codecov / codecov/patch

src/filtering/update.jl#L34-L35

Added lines #L34 - L35 were not covered by tests

K = C * H' * inv(S)
m_new = m - K * z
Expand Down
122 changes: 112 additions & 10 deletions src/gaussians.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,132 @@
############################################################################################
# Useful things when working with GaussianDistributions.Gaussian
# `Gaussian` distributions
# Based on @mschauer's GaussianDistributions.jl
############################################################################################
"""
Gaussian(μ, Σ) -> P

Gaussian distribution with mean `μ` and covariance `Σ`. Defines `rand(P)` and `(log-)pdf(P, x)`.
Designed to work with `Number`s, `UniformScaling`s, `StaticArrays` and `PSD`-matrices.

Implementation details: On `Σ` the functions `logdet`, `whiten` and `unwhiten`
(or `chol` as fallback for the latter two) are called.
"""
struct Gaussian{T,S}
μ::T
Σ::S
Gaussian(μ::T, Σ::S) where {T,S} = new{T,S}(μ, Σ)
end
Base.convert(::Type{Gaussian{T,S}}, g::Gaussian) where {T,S} =
Gaussian(convert(T, g.μ), convert(S, g.Σ))

# Base
Base.:(==)(g1::Gaussian, g2::Gaussian) = g1.μ == g2.μ && g1.Σ == g2.Σ
Base.isapprox(g1::Gaussian, g2::Gaussian; kwargs...) =
isapprox(g1.μ, g2.μ; kwargs...) && isapprox(g1.Σ, g2.Σ; kwargs...)
copy(P::Gaussian) = Gaussian(copy(P.μ), copy(P.Σ))
similar(P::Gaussian) = Gaussian(similar(P.μ), similar(P.Σ))
Base.copyto!(P::AbstractArray{<:Gaussian}, idx::Integer, el::Gaussian) =
(P[idx] = copy(el); P)

Check warning on line 29 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L29

Added line #L29 was not covered by tests
function Base.copy!(dst::Gaussian, src::Gaussian)
copy!(dst.μ, src.μ)
copy!(dst.Σ, src.Σ)
return dst
end
RecursiveArrayTools.recursivecopy(P::Gaussian) = copy(P)
RecursiveArrayTools.recursivecopy!(dst::Gaussian, src::Gaussian) = copy!(dst, src)
show(io::IO, g::Gaussian) = print(io, "Gaussian($(g.μ), $(g.Σ))")
show(io::IO, ::MIME"text/plain", g::Gaussian{T,S}) where {T,S} =
print(io, "Gaussian{$T,$S}($(g.μ), $(g.Σ))")
length(P::Gaussian) = length(mean(P))
size(g::Gaussian) = size(g.μ)
ndims(g::Gaussian) = ndims(g.μ)
eltype(::Type{G}) where {G<:Gaussian} = G
Base.@propagate_inbounds Base.getindex(P::Gaussian, i::Integer) =
Gaussian(P.μ[i], diag(P.Σ)[i])

# Statistics
mean(P::Gaussian) = P.μ
cov(P::Gaussian) = P.Σ
var(P::Gaussian{<:Number}) = P.Σ
std(P::Gaussian{<:Number}) = sqrt(var(P))
var(g::Gaussian) = diag(g.Σ)
std(g::Gaussian) = sqrt.(diag(g.Σ))

dim(P::Gaussian) = length(P.μ)

# whiten(Σ::PSD, z) = Σ.σ\z
whiten(Σ, z) = cholesky(Σ).U' \ z
whiten(Σ::Number, z) = sqrt(Σ) \ z
whiten(Σ::UniformScaling, z) = sqrt(Σ.λ) \ z

Check warning on line 54 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L54

Added line #L54 was not covered by tests

# unwhiten(Σ::PSD, z) = Σ.σ*z
unwhiten(Σ, z) = cholesky(Σ).U' * z
unwhiten(Σ::Number, z) = sqrt(Σ) * z
unwhiten(Σ::UniformScaling, z) = sqrt(Σ.λ) * z

Check warning on line 59 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L58-L59

Added lines #L58 - L59 were not covered by tests

sqmahal(P::Gaussian, x) = norm_sqr(whiten(P.Σ, x - P.μ))

rand(P::Gaussian) = rand(GLOBAL_RNG, P)
rand(RNG::AbstractRNG, P::Gaussian) = P.μ + unwhiten(P.Σ, randn(RNG, typeof(P.μ)))

Check warning on line 64 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L64

Added line #L64 was not covered by tests
rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}) where {T} =
P.μ + unwhiten(P.Σ, randn(RNG, T, length(P.μ)))
rand(RNG::AbstractRNG, P::Gaussian{<:Number}) =
P.μ + sqrt(P.Σ) * randn(RNG, typeof(one(P.μ)))

_logdet(Σ, d) = logdet(Σ)
_logdet(J::UniformScaling, d) = log(J.λ) * d

Check warning on line 71 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L71

Added line #L71 was not covered by tests
logpdf(P::Gaussian, x) = -(sqmahal(P, x) + _logdet(P.Σ, dim(P)) + dim(P) * log(2pi)) / 2
pdf(P::Gaussian, x) = exp(logpdf(P::Gaussian, x))

Base.:+(g::Gaussian, vec) = Gaussian(g.μ + vec, g.Σ)
Base.:+(vec, g::Gaussian) = g + vec
Base.:-(g::Gaussian, vec) = g + (-vec)
Base.:*(M, g::Gaussian) = Gaussian(M * g.μ, X_A_Xt(g.Σ, M))

function rand_scalar(RNG::AbstractRNG, P::Gaussian{T}, dims) where {T}
X = zeros(T, dims)
for i in 1:length(X)
X[i] = rand(RNG, P)
end
X

Check warning on line 85 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L80-L85

Added lines #L80 - L85 were not covered by tests
end

function rand_vector(
RNG::AbstractRNG,
P::Gaussian{Vector{T}},
dims::Union{Integer,NTuple},
) where {T}
X = zeros(T, dim(P), dims...)
for i in 1:prod(dims)
X[:, i] = rand(RNG, P)
end
X

Check warning on line 97 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L97

Added line #L97 was not covered by tests
end
rand(RNG::AbstractRNG, P::Gaussian, dim::Integer) = rand_scalar(RNG, P, dim)
rand(RNG::AbstractRNG, P::Gaussian, dims::Tuple{Vararg{Int64,N}} where {N}) =

Check warning on line 100 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L99-L100

Added lines #L99 - L100 were not covered by tests
rand_scalar(RNG, P, dims)

rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}, dim::Integer) where {T} =
rand_vector(RNG, P, dim)
rand(

Check warning on line 105 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L105

Added line #L105 was not covered by tests
RNG::AbstractRNG,
P::Gaussian{Vector{T}},
dims::Tuple{Vararg{Int64,N}} where {N},
) where {T} = rand_vector(RNG, P, dims)
rand(P::Gaussian, dims::Tuple{Vararg{Int64,N}} where {N}) = rand(GLOBAL_RNG, P, dims)

Check warning on line 110 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L110

Added line #L110 was not covered by tests
rand(P::Gaussian, dim::Integer) = rand(GLOBAL_RNG, P, dim)

# RecursiveArrayTools
RecursiveArrayTools.recursivecopy(P::Gaussian) = copy(P)
RecursiveArrayTools.recursivecopy!(dst::Gaussian, src::Gaussian) = copy!(dst, src)

# Print
show(io::IO, g::Gaussian) = print(io, "Gaussian(μ=$(g.μ), Σ=$(g.Σ))")
show(io::IO, ::MIME"text/plain", g::Gaussian{T,S}) where {T,S} = begin
println(io, "Gaussian{$T,$S}(")
println(io, " μ=$(g.μ),")
println(io, " Σ=$(g.Σ)")
print(io, ")")

Check warning on line 123 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L118-L123

Added lines #L118 - L123 were not covered by tests
end

############################################################################################
# `SRGaussian`: Gaussians with PDFMatrix covariances
############################################################################################
const SRGaussian{T,S} = Gaussian{VM,PSDMatrix{T,S}} where {VM<:AbstractVecOrMat{T}}
Base.:*(M::AbstractMatrix, g::SRGaussian) = Gaussian(M * g.μ, X_A_Xt(g.Σ, M))
# GaussianDistributions.whiten(Σ::PSDMatrix, z) = Σ.L\z

function _gaussian_mul!(g_out::SRGaussian, M::AbstractMatrix, g_in::SRGaussian)
_matmul!(g_out.μ, M, g_in.μ)
fast_X_A_Xt!(g_out.Σ, g_in.Σ, M)
Expand Down
26 changes: 26 additions & 0 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,32 @@
u = proj * x
return Gaussian(u.μ[idxs], diag(u.Σ)[idxs])
end
function (interp::ODEFilterPosterior)(
t::Real,
idxs::AbstractVector{<:Integer},
::Type{deriv},
p,
continuity,
) where {deriv}
q = interp.cache.q
dv = deriv.parameters[1]
proj = if deriv == Val{0}
interp.cache.SolProj
elseif dv <= q
interp.cache.Proj(dv)

Check warning on line 296 in src/solution.jl

View check run for this annotation

Codecov / codecov/patch

src/solution.jl#L295-L296

Added lines #L295 - L296 were not covered by tests
else
throw(ArgumentError("We can only provide derivatives up to $q but you requested $dv"))

Check warning on line 298 in src/solution.jl

View check run for this annotation

Codecov / codecov/patch

src/solution.jl#L298

Added line #L298 was not covered by tests
end
x = interpolate(
t, interp.ts, interp.x_filt, interp.x_smooth, interp.diffusions, interp.cache;
smoothed=interp.smooth)
u = proj * x
P = zeros(Bool, length(idxs), length(u))
for (i, idx) in enumerate(idxs)
P[i, idx] = 1
end
return P * u
end
function (interp::ODEFilterPosterior)(
t::AbstractVector{<:Real},
idxs,
Expand Down
2 changes: 1 addition & 1 deletion test/core/filtering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra
import ProbNumDiffEq: IsometricKroneckerProduct, BlocksOfDiagonals
import ProbNumDiffEq as PNDE
using FillArrays
import ProbNumDiffEq.GaussianDistributions: logpdf
import ProbNumDiffEq: logpdf

@testset "PREDICT" begin
# Setup
Expand Down
Loading