Skip to content

Commit

Permalink
Refactor things a bit that I should have done in the last PR
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Nov 10, 2023
1 parent 451dedc commit b574db9
Showing 1 changed file with 31 additions and 28 deletions.
59 changes: 31 additions & 28 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,40 +37,43 @@ function solution_space_projection(C::CovarianceStructure, is_secondorder_ode)
end
end

struct SecondOrderODESolutionProjector{T,FAC,M} <: AbstractMatrix{T}
function solution_space_projection(C::IsometricKroneckerCovariance, is_secondorder_ode)
Proj = projection(C)
if is_secondorder_ode
return KroneckerSecondOrderODESolutionProjector(C)
else
return Proj(0)
end
end

struct KroneckerSecondOrderODESolutionProjector{T,FAC,M,M2} <: AbstractMatrix{T}
covariance_structure::FAC
E0B::M
E1B::M
E0::M
E1::M
SolProjB::M2
end
function SecondOrderODESolutionProjector(C::IsometricKroneckerCovariance{T}) where {T}
function KroneckerSecondOrderODESolutionProjector(
C::IsometricKroneckerCovariance{T},
) where {T}
Proj = projection(C)
E0B, E1B = Proj(0).B, Proj(1).B
return SecondOrderODESolutionProjector{T,typeof(C),typeof(E0B)}(C, E0B, E1B)
E0, E1 = Proj(0), Proj(1)
SolProjB = [E1.B; E0.B]
return KroneckerSecondOrderODESolutionProjector{
T,typeof(C),typeof(E0),typeof(SolProjB),
}(
C, E0, E1, SolProjB,
)
end
function _gaussian_mul!(
g_out::SRGaussian, M::SecondOrderODESolutionProjector, g_in::SRGaussian)
@unpack d = M.covariance_structure
E0 = IsometricKroneckerProduct(d, M.E0B)
E1 = IsometricKroneckerProduct(d, M.E1B)
_matmul!(view(g_out.μ, 1:d), E1, g_in.μ)
_matmul!(view(g_out.μ, d+1:2d), E0, g_in.μ)
_matmul!(g_out.Σ.R.A, g_in.Σ.R.B, [M.E1B; M.E0B]')
g_out::SRGaussian, M::KroneckerSecondOrderODESolutionProjector, g_in::SRGaussian)
d = M.covariance_structure.d
_matmul!(view(g_out.μ, 1:d), M.E1, g_in.μ)
_matmul!(view(g_out.μ, d+1:2d), M.E0, g_in.μ)
_matmul!(g_out.Σ.R.A, g_in.Σ.R.B, M.SolProjB')
return g_out
end
function Base.:*(M::SecondOrderODESolutionProjector, x::SRGaussian)
@unpack d = M.covariance_structure
E0 = IsometricKroneckerProduct(d, M.E0B)
E1 = IsometricKroneckerProduct(d, M.E1B)
μ = [E1 * x.μ; E0 * x.μ]
Σ = PSDMatrix([x.Σ.R * E1' x.Σ.R * E0'])
function Base.:*(M::KroneckerSecondOrderODESolutionProjector, x::SRGaussian)
μ = [M.E1 * x.μ; M.E0 * x.μ]
Σ = PSDMatrix([x.Σ.R * M.E1' x.Σ.R * M.E0'])
return Gaussian(μ, Σ)
end

function solution_space_projection(C::IsometricKroneckerCovariance, is_secondorder_ode)
Proj = projection(C)
if is_secondorder_ode
return SecondOrderODESolutionProjector(C)
else
return Proj(0)
end
end

0 comments on commit b574db9

Please sign in to comment.