Skip to content

Commit

Permalink
Shorten some code now that CovarianceStructure holds elType, d, q
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Oct 28, 2023
1 parent da5c47c commit 82b2a8f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ function OrdinaryDiffEq.alg_cache(
matType = typeof(factorized_similar(FAC, d, d))

# Projections
Proj = projection(FAC, d, q, uElType)
Proj = projection(FAC)
E0, E1, E2 = Proj(0), Proj(1), Proj(2)
@assert f isa SciMLBase.AbstractODEFunction
SolProj = if is_secondorder_ode
Expand Down
22 changes: 6 additions & 16 deletions src/preconditioning.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
function init_preconditioner(
FAC::IsometricKroneckerCovariance,
d,
q,
::Type{elType}=typeof(1.0),
) where {elType}
P = IsometricKroneckerProduct(d, Diagonal(ones(elType, q + 1)))
PI = IsometricKroneckerProduct(d, Diagonal(ones(elType, q + 1)))
function init_preconditioner(C::IsometricKroneckerCovariance{elType}) where {elType}
P = IsometricKroneckerProduct(C.d, Diagonal(ones(elType, C.q + 1)))
PI = IsometricKroneckerProduct(C.d, Diagonal(ones(elType, C.q + 1)))
return P, PI
end
function init_preconditioner(
FAC::DenseCovariance,
d,
q,
::Type{elType}=typeof(1.0),
) where {elType}
P = kron(I(d), Diagonal(ones(elType, q + 1)))
PI = kron(I(d), Diagonal(ones(elType, q + 1)))
function init_preconditioner(C::DenseCovariance{elType}) where {elType}
P = kron(I(C.d), Diagonal(ones(elType, C.q + 1)))
PI = kron(I(C.d), Diagonal(ones(elType, C.q + 1)))
return P, PI
end

Expand Down
7 changes: 4 additions & 3 deletions src/priors/common.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
abstract type AbstractODEFilterPrior{elType} end

function initialize_preconditioner(
FAC::CovarianceStructure,
FAC::CovarianceStructure{T1},
p::AbstractODEFilterPrior{T},
dt,
) where {T}
) where {T,T1}
@assert T == T1
d, q = p.wiener_process_dimension, p.num_derivatives
P, PI = init_preconditioner(FAC, d, q, T)
P, PI = init_preconditioner(FAC)
make_preconditioner!(P, dt, d, q)
make_preconditioner_inv!(PI, dt, d, q)
return P, PI
Expand Down
22 changes: 6 additions & 16 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,17 @@ function projection(
end
return Proj
end
function projection(
::DenseCovariance,
d::Integer,
q::Integer,
::Type{elType}=typeof(1.0),
) where {elType}
projection(d, q, elType)
function projection(C::DenseCovariance{elType}) where {elType}
projection(C.d, C.q, elType)
end

function projection(
::IsometricKroneckerCovariance,
d::Integer,
q::Integer,
::Type{elType}=typeof(1.0),
) where {elType}
function projection(C::IsometricKroneckerCovariance{elType}) where {elType}
Proj(deriv) = begin
e_i = zeros(elType, q + 1, 1)
if deriv <= q
e_i = zeros(elType, C.q + 1, 1)
if deriv <= C.q
e_i[deriv+1] = 1
end
return IsometricKroneckerProduct(d, e_i')
return IsometricKroneckerProduct(C.d, e_i')
end
return Proj
end

0 comments on commit 82b2a8f

Please sign in to comment.