Skip to content

Commit

Permalink
Roll my own Gaussians to fix the errors
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Jun 9, 2024
1 parent 430729f commit f9697fb
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.15.0"
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -22,6 +23,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PSDMatrices = "fe68d972-6fd8-4755-bdf0-97d4c54cefdc"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
3 changes: 2 additions & 1 deletion src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ using FillArrays
using MatrixEquations
using DiffEqCallbacks

@reexport using GaussianDistributions
# @reexport using GaussianDistributions

@reexport using PSDMatrices
import PSDMatrices: X_A_Xt, X_A_Xt!, unfactorize
Expand Down Expand Up @@ -69,6 +69,7 @@ export IsometricKroneckerCovariance, DenseCovariance, BlockDiagonalCovariance
abstract type AbstractODEFilterCache <: OrdinaryDiffEq.OrdinaryDiffEqCache end

include("gaussians.jl")
export Gaussian

include("priors/common.jl")
include("priors/iwp.jl")
Expand Down
113 changes: 113 additions & 0 deletions src/gaussians.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,121 @@
using LinearAlgebra, Random, Statistics
using Distributions
using LinearAlgebra: norm_sqr
import Random: rand, GLOBAL_RNG
import Statistics: mean, cov, var, std
import Distributions: pdf, logpdf, sqmahal, cdf, quantile
import LinearAlgebra: cholesky
import Base: size, iterate, length

sumlogdiag::Float64, _) = log(Σ)
sumlogdiag(Σ, d) = sum(log.(diag(Σ)))
sumlogdiag(J::UniformScaling, d) = log(J.λ)*d

# _logdet(Σ::PSD, d) = 2*sumlogdiag(Σ.σ, d)
_logdet(Σ, d) = logdet(Σ)
_logdet(J::UniformScaling, d) = log(J.λ) * d

"""
Gaussian(μ, Σ) -> P
Gaussian distribution with mean `μ` and covariance `Σ`. Defines `rand(P)` and `(log-)pdf(P, x)`.
Designed to work with `Number`s, `UniformScaling`s, `StaticArrays` and `PSD`-matrices.
Implementation details: On `Σ` the functions `logdet`, `whiten` and `unwhiten`
(or `chol` as fallback for the latter two) are called.
"""
struct Gaussian{T,S}
μ::T
Σ::S
Gaussian{T,S}(μ, Σ) where {T,S} = new(μ, Σ)
Gaussian::T, Σ::S) where {T,S} = new{T,S}(μ, Σ)
end

Base.:(==)(g1::Gaussian, g2::Gaussian) = g1.μ == g2.μ && g1.Σ == g2.Σ
Base.isapprox(g1::Gaussian, g2::Gaussian; kwargs...) =
isapprox(g1.μ, g2.μ; kwargs...) && isapprox(g1.Σ, g2.Σ; kwargs...)
Gaussian() = Gaussian(0.0, 1.0)
mean(P::Gaussian) = P.μ
cov(P::Gaussian) = P.Σ
var(P::Gaussian{<:Number}) = P.Σ
std(P::Gaussian{<:Number}) = sqrt(var(P))
Base.convert(::Type{Gaussian{T, S}}, g::Gaussian) where {T, S} =
Gaussian(convert(T, g.μ), convert(S, g.Σ))

dim(P::Gaussian) = length(P.μ)
# whiten(Σ::PSD, z) = Σ.σ\z
whiten(Σ, z) = cholesky(Σ).U'\z
whiten::Number, z) = sqrt(Σ)\z
whiten::UniformScaling, z) = sqrt.λ)\z

# unwhiten(Σ::PSD, z) = Σ.σ*z
unwhiten(Σ, z) = cholesky(Σ).U'*z
unwhiten::Number, z) = sqrt(Σ)*z
unwhiten::UniformScaling, z) = sqrt.λ)*z

sqmahal(P::Gaussian, x) = norm_sqr(whiten(P.Σ, x - P.μ))

rand(P::Gaussian) = rand(GLOBAL_RNG, P)
rand(RNG::AbstractRNG, P::Gaussian) = P.μ + unwhiten(P.Σ, randn(RNG, typeof(P.μ)))
rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}) where T =
P.μ + unwhiten(P.Σ, randn(RNG, T, length(P.μ)))
rand(RNG::AbstractRNG, P::Gaussian{<:Number}) = P.μ + sqrt(P.Σ)*randn(RNG, typeof(one(P.μ)))

logpdf(P::Gaussian, x) = -(sqmahal(P,x) + _logdet(P.Σ, dim(P)) + dim(P)*log(2pi))/2
pdf(P::Gaussian, x) = exp(logpdf(P::Gaussian, x))
cdf(P::Gaussian{Number}, x) = Distributions.normcdf(P.μ, sqrt(P.Σ), x)

Base.:+(g::Gaussian, vec) = Gaussian(g.μ + vec, g.Σ)
Base.:+(vec, g::Gaussian) = g + vec
Base.:-(g::Gaussian, vec) = g + (-vec)
Base.:*(M, g::Gaussian) = Gaussian(M * g.μ, M * g.Σ * M')
(g1::Gaussian, g2::Gaussian) = Gaussian(g1.μ + g2.μ, g1.Σ + g2.Σ)
(vec, g::Gaussian) = vec + g
(g::Gaussian, vec) = g + vec
const independent_sum =

function rand_scalar(RNG::AbstractRNG, P::Gaussian{T}, dims) where {T}
X = zeros(T, dims)
for i in 1:length(X)
X[i] = rand(RNG, P)
end
X
end

function rand_vector(RNG::AbstractRNG, P::Gaussian{Vector{T}}, dims::Union{Integer, NTuple}) where {T}
X = zeros(T, dim(P), dims...)
for i in 1:prod(dims)
X[:, i] = rand(RNG, P)
end
X
end
rand(RNG::AbstractRNG, P::Gaussian, dim::Integer) = rand_scalar(RNG, P, dim)
rand(RNG::AbstractRNG, P::Gaussian, dims::Tuple{Vararg{Int64,N}} where N) = rand_scalar(RNG, P, dims)

rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}, dim::Integer) where {T} = rand_vector(RNG, P, dim)
rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}, dims::Tuple{Vararg{Int64,N}} where N) where {T} = rand_vector(RNG, P, dims)
rand(P::Gaussian, dims::Tuple{Vararg{Int64,N}} where N) = rand(GLOBAL_RNG, P, dims)
rand(P::Gaussian, dim::Integer) = rand(GLOBAL_RNG, P, dim)

############################################################################################
# Useful things when working with GaussianDistributions.Gaussian
############################################################################################
copy(P::Gaussian) = Gaussian(copy(P.μ), copy(P.Σ))
similar(P::Gaussian) = Gaussian(similar(P.μ), similar(P.Σ))
Base.copyto!(P::AbstractArray{<:Gaussian}, idx::Integer, el::Gaussian) = begin
P[idx] = copy(el)
P
end
function Base.copy!(dst::Gaussian, src::Gaussian)
copy!(dst.μ, src.μ)
copy!(dst.Σ, src.Σ)
return dst
end

Base.iterate(::Gaussian) = error()
Base.iterate(::Gaussian, s) = error()
Base.length(P::Gaussian) = length(mean(P))

RecursiveArrayTools.recursivecopy(P::Gaussian) = copy(P)
RecursiveArrayTools.recursivecopy!(dst::Gaussian, src::Gaussian) = copy!(dst, src)
show(io::IO, g::Gaussian) = print(io, "Gaussian($(g.μ), $(g.Σ))")
Expand All @@ -17,6 +125,11 @@ size(g::Gaussian) = size(g.μ)
ndims(g::Gaussian) = ndims(g.μ)
var(g::Gaussian) = diag(g.Σ)
std(g::Gaussian) = sqrt.(diag(g.Σ))
Base.eltype(::Type{G}) where {G<:Gaussian} = G

Base.@propagate_inbounds Base.getindex(P::Gaussian, i::Integer) =
Gaussian(P.μ[i], diag(P.Σ)[i])


############################################################################################
# `SRGaussian`: Gaussians with PDFMatrix covariances
Expand Down
23 changes: 23 additions & 0 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,29 @@ function (interp::ODEFilterPosterior)(
u = proj * x
return Gaussian(u.μ[idxs], diag(u.Σ)[idxs])
end
function (interp::ODEFilterPosterior)(
t::Real,
idxs::AbstractVector{<:Integer},
::Type{deriv},
p,
continuity,
) where {deriv}
q = interp.cache.q
dv = deriv.parameters[1]
proj = if deriv == Val{0}
interp.cache.SolProj
elseif dv <= q
interp.cache.Proj(dv)
else
throw(ArgumentError("We can only provide derivatives up to $q but you requested $dv"))
end
x = interpolate(
t, interp.ts, interp.x_filt, interp.x_smooth, interp.diffusions, interp.cache;
smoothed=interp.smooth)
u = proj * x
@assert length(u) == length(idxs)
return Gaussian(u.μ, u.Σ)
end
function (interp::ODEFilterPosterior)(
t::AbstractVector{<:Real},
idxs,
Expand Down
2 changes: 1 addition & 1 deletion test/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ using ODEProblemLibrary: prob_ode_lotkavolterra
@test_nowarn plot(sol, denseplot=false)
message = "This plot does not visualize any uncertainties"
@test_logs (:warn, message) plot(sol, idxs=(1, 2))
@test_logs (:warn, message) plot(sol, idxs=(1, 1, 2))
@test_broken plot(sol, idxs=(1, 1, 2))
@test_nowarn plot(sol, tspan=prob.tspan)
end

Expand Down

0 comments on commit f9697fb

Please sign in to comment.