Skip to content

Commit

Permalink
Add covariance_factorizations.jl to clean up the if/else KRONECKER …
Browse files Browse the repository at this point in the history
…stuff
  • Loading branch information
nathanaelbosch committed Oct 20, 2023
1 parent 7085943 commit 9d2b481
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 77 deletions.
1 change: 1 addition & 0 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ vecvec2mat(x) = reduce(hcat, x)'

include("fast_linalg.jl")
include("kronecker.jl")
include("covariance_factorizations.jl")

abstract type AbstractODEFilterCache <: OrdinaryDiffEq.OrdinaryDiffEqCache end

Expand Down
9 changes: 0 additions & 9 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,6 @@ EK1(;
initialization,
)

iskronecker(alg, f) = (
alg isa EK0
&&
!(alg.diffusionmodel isa DynamicMVDiffusion ||
alg.diffusionmodel isa FixedMVDiffusion)
&& alg.prior isa IWP
&& f.mass_matrix isa UniformScaling
)

"""
ExpEK(; L, order=3, kwargs...)
Expand Down
67 changes: 17 additions & 50 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,30 +91,29 @@ function OrdinaryDiffEq.alg_cache(
d = is_secondorder_ode ? length(u[1, :]) : length(u)
D = d * (q + 1)

KRONECKER = iskronecker(alg, f)
FAC = get_covariance_factorization(alg)
if FAC isa KroneckerCovariance && !(f.mass_matrix isa UniformScaling)
error("The selected algorithm uses an efficient Kronecker-factorized implementation which is incompatible with the provided mass matrix. Try using the `EK1` instead.")
end

uType = typeof(u)
# uElType = eltype(u_vec)
uElType = uBottomEltypeNoUnits
matType = Matrix{uElType}

# Projections
Proj = projection(d, q, uElType)
Proj = projection(FAC, d, q, uElType)
E0, E1, E2 = Proj(0), Proj(1), Proj(2)
@assert f isa SciMLBase.AbstractODEFunction
SolProj = if is_secondorder_ode
if KRONECKER
if E0 isa IKP
IsoKroneckerProduct(d, [Proj(1).B; Proj(0).B])
else
SolProj = [Proj(1); Proj(0)]
end
else
Proj(0)
end
if !KRONECKER
E0, E1, E2 = Matrix(E0), Matrix(E1), Matrix(E2)
SolProj = Matrix(SolProj)
end

# Prior dynamics
prior = if alg.prior isa IWP
Expand All @@ -129,23 +128,15 @@ function OrdinaryDiffEq.alg_cache(
else
error("Invalid prior $(alg.prior)")
end
A, Q, Ah, Qh, P, PI = initialize_transition_matrices(prior, dt)
if !KRONECKER
P, PI = Diagonal(P), Diagonal(PI)
A, Q, Ah, Qh = Matrix(A), PSDMatrix(Matrix(Q.R)), Matrix(Ah), PSDMatrix(Matrix(Qh.R))
end
A, Q, Ah, Qh, P, PI = initialize_transition_matrices(FAC, prior, dt)

# Measurement Model
measurement_model = make_measurement_model(f)

# Initial State
initial_variance = ones(uElType, q + 1)
μ0 = zeros(uElType, D)
Σ0 = PSDMatrix(if KRONECKER
IsoKroneckerProduct(d, diagm(sqrt.(initial_variance)))
else
kron(I(d), diagm(sqrt.(initial_variance)))
end)
Σ0 = PSDMatrix(to_factorized_matrix(FAC, IsoKroneckerProduct(d, diagm(sqrt.(initial_variance)))))
x0 = Gaussian(μ0, Σ0)

# Diffusion Model
Expand All @@ -155,17 +146,9 @@ function OrdinaryDiffEq.alg_cache(

# Measurement model related things
R = zeros(uElType, d, d)
H = if KRONECKER
IsoKroneckerProduct(d, zeros(uElType, 1, q + 1))
else
zeros(uElType, d, D)
end
H = factorized_zeros(FAC, uElType, d, D; d, q)
v = zeros(uElType, d)
S = if KRONECKER
PSDMatrix(IsoKroneckerProduct(d, zeros(uElType, q + 1)))
else
PSDMatrix(zeros(uElType, D, d))
end
S = PSDMatrix(factorized_zeros(FAC, uElType, D, d; d, q))
measurement = Gaussian(v, S)

# Caches
Expand All @@ -174,22 +157,11 @@ function OrdinaryDiffEq.alg_cache(
pu_tmp = if !is_secondorder_ode # same dimensions as `measurement`
copy(measurement)
else # then `u` has 2d dimensions
Gaussian(
zeros(uElType, 2d),
PSDMatrix(
if KRONECKER
IsoKroneckerProduct(d, zeros(uElType, q+1, 2))
else
zeros(uElType, D, 2d)
end))
Gaussian(zeros(uElType, 2d), PSDMatrix(factorized_zeros(FAC, uElType, D, 2d; d, q)))
end
K = zeros(uElType, D, d)
G = zeros(uElType, D, D)
Smat = if KRONECKER
IsoKroneckerProduct(d, zeros(uElType, 1, 1))
else
zeros(uElType, d, d)
end
Smat = factorized_zeros(FAC, uElType, d, d; d, q)

C_dxd = zeros(uElType, d, d)
C_dxD = zeros(uElType, d, D)
Expand All @@ -198,16 +170,11 @@ function OrdinaryDiffEq.alg_cache(
C_2DxD = zeros(uElType, 2D, D)
C_3DxD = zeros(uElType, 3D, D)

backward_kernel = if KRONECKER
AffineNormalKernel(
IsoKroneckerProduct(d, zeros(uElType, q+1, q+1)),
zeros(uElType, D),
PSDMatrix(IsoKroneckerProduct(d, zeros(uElType, 2*(q+1), q+1)))
)
else
AffineNormalKernel(
zeros(uElType, D, D), zeros(uElType, D), PSDMatrix(zeros(uElType, 2D, D)))
end
backward_kernel = AffineNormalKernel(
factorized_zeros(FAC, uElType, D, D; d, q),
zeros(uElType, D),
PSDMatrix(factorized_zeros(FAC, uElType, 2D, D; d, q))
)

u_pred = copy(u)
u_filt = copy(u)
Expand Down
31 changes: 31 additions & 0 deletions src/covariance_factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
abstract type CovarianceFactorization end
struct KroneckerCovariance <: CovarianceFactorization end
struct DenseCovariance <: CovarianceFactorization end

function get_covariance_factorization(alg)
if (
alg isa EK0 &&
!(
alg.diffusionmodel isa DynamicMVDiffusion ||
alg.diffusionmodel isa FixedMVDiffusion
) &&
alg.prior isa IWP
)
return KroneckerCovariance()
else
return DenseCovariance()
end
end

factorized_zeros(::KroneckerCovariance, elType, sizes...; d, q) = begin
for s in sizes
@assert s % d == 0
end
return IsoKroneckerProduct(d, zeros(elType, (s ÷ d for s in sizes)...))
end

factorized_zeros(::DenseCovariance, elType, sizes...; d, q) = zeros(elType, sizes...)

to_factorized_matrix(::DenseCovariance, M::AbstractMatrix) = Matrix(M)
to_factorized_matrix(::KroneckerCovariance, M::AbstractMatrix) = IsoKroneckerProduct(M) # probably errors
to_factorized_matrix(::KroneckerCovariance, M::IKP) = M
13 changes: 11 additions & 2 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,16 @@ Base.:*(K::IKP, a::Number) = IsoKroneckerProduct(K.ldim, K.B * a)
Base.:*(a::Number, K::IKP) = IsoKroneckerProduct(K.ldim, a * K.B)
LinearAlgebra.adjoint(A::IKP) = IsoKroneckerProduct(A.ldim, A.B')

function check_same_size(A::IKP, B::IKP)
if A.ldim != B.ldim || size(A.B) != size(B.B)
Ad, An, Am, Bd, Bn, Bm = A.ldim, size(A)..., B.ldim, size(B)...
throw(
DimensionMismatch("A has size ($Ad$An,$Ad$Am), B has size ($Bd$Bn,$Bd$Bm)"),
)
end
end
Base.:+(A::IKP, B::IKP) = begin
@assert A.ldim == B.ldim
check_same_size(A, B)
return IsoKroneckerProduct(A.ldim, A.B + B.B)
end
Base.:+(U::UniformScaling, K::IKP) = IsoKroneckerProduct(K.ldim, U + K.B)
Expand Down Expand Up @@ -68,7 +76,7 @@ Base.size(K::IKP) = (K.ldim * size(K.B, 1), K.ldim * size(K.B, 2))
# conversion
Base.convert(::Type{T}, K::IKP) where {T<:IKP} =
K isa T ? K : T(K)
function IKP{T,TB}(K::IKP) where {T,TA,TB}
function IKP{T,TB}(K::IKP) where {T,TB}
IKP(K.ldim, convert(TB, K.B))
end

Expand All @@ -80,6 +88,7 @@ reshape_no_alloc(a, dims::Tuple) =
invoke(Base._reshape, Tuple{AbstractArray,typeof(dims)}, a, dims)
# reshape_no_alloc(a::AbstractArray, dims::Tuple) = reshape(a, dims)
reshape_no_alloc(a, dims...) = reshape_no_alloc(a, Tuple(dims))
reshape_no_alloc(a::Missing, dims::Tuple) = missing

function mul_vectrick!(x::AbstractVecOrMat, A::IsoKroneckerProduct, v::AbstractVecOrMat)
N = A.B
Expand Down
5 changes: 4 additions & 1 deletion src/priors/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ See also: [`make_transition_matrices`](@ref).
[1] N. Krämer, P. Hennig: **Stable Implementation of Probabilistic ODE Solvers** (2020)
"""
function initialize_transition_matrices(p::AbstractODEFilterPrior{T}, dt) where {T}
function initialize_transition_matrices(::DenseCovariance, p::AbstractODEFilterPrior{T}, dt) where {T}
d, q = p.wiener_process_dimension, p.num_derivatives
D = d * (q + 1)
Ah, Qh = zeros(T, D, D), PSDMatrix(zeros(T, D, D))
Expand All @@ -40,6 +40,9 @@ function initialize_transition_matrices(p::AbstractODEFilterPrior{T}, dt) where
Q = copy(Qh)
return A, Q, Ah, Qh, P, PI
end
function initialize_transition_matrices(fac::CovarianceFactorization, p::AbstractODEFilterPrior, dt)
error("The chosen prior can not be implemented with a $fac factorization")
end

"""
make_transition_matrices!(cache, prior::AbstractODEFilterPrior, dt)
Expand Down
8 changes: 7 additions & 1 deletion src/priors/iwp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,19 @@ function discretize(p::IWP, dt::Real)
return A, Q
end

function initialize_transition_matrices(p::IWP{T}, dt) where {T}
function initialize_transition_matrices(::KroneckerCovariance, p::IWP{T}, dt) where {T}
A, Q = preconditioned_discretize(p)
P, PI = initialize_preconditioner(p, dt)
Ah = PI * A * P
Qh = PSDMatrix(Q.R * PI)
return A, Q, Ah, Qh, P, PI
end
function initialize_transition_matrices(::DenseCovariance, p::IWP{T}, dt) where {T}
A, Q, Ah, Qh, P, PI = initialize_transition_matrices(KroneckerCovariance(), p, dt)
P, PI = Diagonal(P), Diagonal(PI)
A, Q, Ah, Qh = Matrix(A), PSDMatrix(Matrix(Q.R)), Matrix(Ah), PSDMatrix(Matrix(Qh.R))
return A, Q, Ah, Qh, P, PI
end

function make_transition_matrices!(cache, prior::IWP, dt)
@unpack A, Q, Ah, Qh, P, PI = cache
Expand Down
28 changes: 25 additions & 3 deletions src/projection.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,32 @@
function projection(d::Integer, q::Integer, ::Type{elType}=typeof(1.0)) where {elType}
function projection(
::DenseCovariance,
d::Integer,
q::Integer,
::Type{elType}=typeof(1.0),
) where {elType}
D = d * (q + 1)
Proj(deriv) = begin
e_i = zeros(elType, q+1, 1)
P = zeros(elType, d, D)
@simd ivdep for i in deriv*d+1:D+1:d*D
@inbounds P[i] = 1
end
return P
end
return Proj
end

function projection(
::KroneckerCovariance,
d::Integer,
q::Integer,
::Type{elType}=typeof(1.0),
) where {elType}
Proj(deriv) = begin
e_i = zeros(elType, q + 1, 1)
if deriv <= q
e_i[deriv+1] = 1
end
IsoKroneckerProduct(d, e_i')
return IsoKroneckerProduct(d, e_i')
end
return Proj
end
15 changes: 4 additions & 11 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ function DiffEqBase.build_solution(
true,
Val(isinplace(prob)),
)
q = cache.q

T = eltype(eltype(u))
N = length((size(prob.u0)..., length(u)))
Expand All @@ -93,18 +94,10 @@ function DiffEqBase.build_solution(
uElType = eltype(prob.u0)
D = d

KRONECKER = iskronecker(alg, prob.f)
FAC = get_covariance_factorization(alg)

pu_cov = if KRONECKER
PSDMatrix(IsoKroneckerProduct(d, zeros(uElType, D ÷ d + 1)))
else
PSDMatrix(zeros(uElType, D, d))
end
x_cov = if KRONECKER
PSDMatrix(IsoKroneckerProduct(d, zeros(uElType, D ÷ d + 1, D ÷ d + 1)))
else
PSDMatrix(zeros(uElType, D, D))
end
pu_cov = PSDMatrix(factorized_zeros(FAC, uElType, D, d; d, q))
x_cov = PSDMatrix(factorized_zeros(FAC, uElType, D, D; d, q))
pu = StructArray{Gaussian{Vector{uElType},typeof(pu_cov)}}(undef, 0)
x_filt = StructArray{Gaussian{Vector{uElType},typeof(x_cov)}}(undef, 0)
x_smooth = copy(x_filt)
Expand Down

0 comments on commit 9d2b481

Please sign in to comment.