Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Oct 26, 2023
1 parent 3810b76 commit 8dffaac
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 34 deletions.
85 changes: 56 additions & 29 deletions test/core/filtering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Test
using ProbNumDiffEq
using LinearAlgebra
import ProbNumDiffEq: IsometricKroneckerProduct
import ProbNumDiffEq as PNDE

@testset "PREDICT" begin
# Setup
Expand All @@ -25,8 +26,13 @@ import ProbNumDiffEq: IsometricKroneckerProduct
x_curr = Gaussian(m, P)
x_out = copy(x_curr)

C_DxD = zeros(d, d)
C_2DxD = zeros(2d, d)
C_3DxD = zeros(3d, d)

_fstr(F) = F ? "Kronecker" : "None"
@testset "Factorization: $(_fstr(KRONECKER))" for KRONECKER in (false, true)
FAC = KRONECKER ? PNDE.IsometricKroneckerCovariance() : PNDE.DenseCovariance()
if KRONECKER
K = 2
m = kron(ones(K), m)
Expand All @@ -42,6 +48,10 @@ import ProbNumDiffEq: IsometricKroneckerProduct

x_curr = Gaussian(m, P)
x_out = copy(x_curr)

C_DxD = IsometricKroneckerProduct(K, C_DxD)
C_2DxD = IsometricKroneckerProduct(K, C_2DxD)
C_3DxD = IsometricKroneckerProduct(K, C_3DxD)
end

@testset "predict" begin
Expand All @@ -54,7 +64,7 @@ import ProbNumDiffEq: IsometricKroneckerProduct
x_curr = Gaussian(m, PSDMatrix(P_R))
x_out = copy(x_curr)
Q_SR = PSDMatrix(Q_R)
ProbNumDiffEq.predict!(x_out, x_curr, A, Q_SR, zeros(d, d), zeros(2d, d))
ProbNumDiffEq.predict!(x_out, x_curr, A, Q_SR, C_DxD, C_2DxD)
@test m_p x_out.μ
@test P_p Matrix(x_out.Σ)
end
Expand All @@ -63,7 +73,7 @@ import ProbNumDiffEq: IsometricKroneckerProduct
x_curr = Gaussian(m, PSDMatrix(P_R))
x_out = copy(x_curr)
Q_SR = PSDMatrix(Q_R)
ProbNumDiffEq.predict!(x_out, x_curr, A, Q_SR, zeros(d, d), zeros(2d, d), 0)
ProbNumDiffEq.predict!(x_out, x_curr, A, Q_SR, C_DxD, C_2DxD, 0)
@test m_p x_out.μ
@test Matrix(x_out.Σ) Matrix(X_A_Xt(x_curr.Σ, A))
end
Expand All @@ -78,13 +88,8 @@ import ProbNumDiffEq: IsometricKroneckerProduct
PSDMatrix([Q_R; zero(Q_R)])
end
K = ProbNumDiffEq.AffineNormalKernel(A, Q_SR)
ProbNumDiffEq.marginalize!(
x_out,
x_curr,
K;
C_DxD=zeros(d, d),
C_3DxD=zeros(3d, d),
)
T = eltype(m)
ProbNumDiffEq.marginalize!(x_out, x_curr, K; C_DxD, C_3DxD)
@test m_p x_out.μ
@test P_p Matrix(x_out.Σ)
end
Expand Down Expand Up @@ -119,6 +124,12 @@ end
x_out = copy(x_pred)
measurement = Gaussian(z, S)

C_dxd = zeros(o, o)
C_Dxd = zeros(d, o)
C_DxD = zeros(d, d)
C_2DxD = zeros(2d, d)
C_3DxD = zeros(3d, d)

_fstr(F) = F ? "Kronecker" : "None"
@testset "Factorization: $(_fstr(KRONECKER))" for KRONECKER in (false, true)
if KRONECKER
Expand All @@ -134,6 +145,12 @@ end
x_pred = Gaussian(m_p, P_p)
x_out = copy(x_pred)
measurement = Gaussian(z, S)

C_dxd = IsometricKroneckerProduct(1, C_dxd)
C_Dxd = IsometricKroneckerProduct(1, C_Dxd)
C_DxD = IsometricKroneckerProduct(1, C_DxD)
C_2DxD = IsometricKroneckerProduct(1, C_2DxD)
C_3DxD = IsometricKroneckerProduct(1, C_3DxD)
end

@testset "update" begin
Expand Down Expand Up @@ -162,12 +179,12 @@ end
# @test P ≈ Matrix(x_out.Σ)
# end
@testset "PSDMatrix" begin
K_cache = copy(K)
K2_cache = copy(K)
M_cache = zeros(d, d)
K_cache = copy(C_Dxd)
K2_cache = copy(C_Dxd)
M_cache = C_DxD
S = measurement.Σ
msmnt = Gaussian(measurement.μ, PSDMatrix(SR))
O_cache = zeros(o, o)
O_cache = C_dxd
x_pred = Gaussian(x_pred.μ, PSDMatrix(P_p_R))
x_out = copy(x_pred)
ProbNumDiffEq.update!(
Expand All @@ -184,12 +201,12 @@ end
@test P Matrix(x_out.Σ)
end
@testset "Zero predicted covariance" begin
K_cache = copy(K)
K2_cache = copy(K)
M_cache = zeros(d, d)
K_cache = copy(C_Dxd)
K2_cache = copy(C_Dxd)
M_cache = C_DxD
S = measurement.Σ
msmnt = Gaussian(measurement.μ, PSDMatrix(SR))
O_cache = zeros(o, o)
O_cache = C_dxd
x_pred = Gaussian(x_pred.μ, PSDMatrix(zero(P_p_R)))
x_out = copy(x_pred)
ProbNumDiffEq.update!(
Expand Down Expand Up @@ -244,7 +261,7 @@ end
P, P_s = P_R'P_R, P_s_R'P_s_R

A = rand(d, d)
Q_R = Matrix(UpperTriangular(rand(d, d)))
Q_R = Matrix(UpperTriangular(rand(d, d))) + I
Q = Q_R'Q_R
Q_SR = PSDMatrix(Q_R)

Expand Down Expand Up @@ -287,7 +304,7 @@ end
x_next = Gaussian(m_s, P_s)

m_smoothed = kron(ones(K), m_smoothed)
P_smoothed = kron(I(K), P_smoothed)
P_smoothed = IsometricKroneckerProduct(K, P_smoothed)
x_smoothed = Gaussian(m_smoothed, P_smoothed)
end

Expand All @@ -304,16 +321,22 @@ end
@test P_smoothed Matrix(x_out.Σ)
end
@testset "smooth!" begin
_d = KRONECKER ? K * d : d
_d = d
x_curr_psd = Gaussian(m, PSDMatrix(P_R)) |> copy
x_next_psd = Gaussian(m_s, PSDMatrix(P_s_R)) |> copy
cache = (
x_pred=copy(x_curr_psd),
G1=zeros(_d, _d),
C_DxD=zeros(_d, _d),
C_2DxD=zeros(2_d, _d),
C_3DxD=zeros(3_d, _d),
)
cache = if !KRONECKER
(x_pred=copy(x_curr_psd),
G1=zeros(_d, _d),
C_DxD=zeros(_d, _d),
C_2DxD=zeros(2_d, _d),
C_3DxD=zeros(3_d, _d))
else
(x_pred=copy(x_curr_psd),
G1=IsometricKroneckerProduct(K, zeros(_d, _d)),
C_DxD=IsometricKroneckerProduct(K, zeros(_d, _d)),
C_2DxD=IsometricKroneckerProduct(K, zeros(2_d, _d)),
C_3DxD=IsometricKroneckerProduct(K, zeros(3_d, _d)))
end
ProbNumDiffEq.smooth!(
x_curr_psd,
x_next_psd,
Expand All @@ -340,7 +363,7 @@ end
x_next_smoothed = Gaussian(m_s, PSDMatrix(P_s_R)) |> copy

C_DxD = if KRONECKER
zeros(K * d, K * d)
IsometricKroneckerProduct(K, zeros(d, d))
else
zeros(d, d)
end
Expand All @@ -354,7 +377,11 @@ end
@test K_backward.b b
@test Matrix(K_backward.C) Λ

C_3DxD = zeros(3d, d)
C_3DxD = if KRONECKER
IsometricKroneckerProduct(K, zeros(3d, d))
else
zeros(3d, d)
end
ProbNumDiffEq.marginalize_mean!(x_curr.μ, x_next_smoothed.μ, K_backward)
ProbNumDiffEq.marginalize_cov!(
x_curr.Σ,
Expand Down
15 changes: 10 additions & 5 deletions test/state_init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,18 @@ import ODEProblemLibrary: prob_ode_fitzhughnagumo, prob_ode_pleiades
end

@testset "Low-order exact init via ClassiSolverInit: `initial_update!`" begin
_q = 2
@test_nowarn init(prob, EK0(order=1, initialization=ClassicSolverInit(init_on_ddu=true)))
@test_nowarn init(prob, EK0(order=2, initialization=ClassicSolverInit(init_on_ddu=false)))
@test_broken init(prob, EK0(order=2, initialization=ClassicSolverInit(init_on_ddu=true)))

@test_nowarn init(prob, EK1(order=1, initialization=ClassicSolverInit(init_on_ddu=true)))
@test_nowarn init(prob, EK1(order=2, initialization=ClassicSolverInit(init_on_ddu=true)))
integ =
init(prob, EK0(order=_q, initialization=ClassicSolverInit(init_on_ddu=true)))
init(prob, EK1(order=2, initialization=ClassicSolverInit(init_on_ddu=true)))
ProbNumDiffEq.initial_update!(integ, integ.cache, integ.alg.initialization)
x = integ.cache.x
@test reshape(x.μ, :, 2)'[:] true_init_states[1:(_q+1)*d]
@test reshape(x.μ, :, 2)'[:] true_init_states[1:(2+1)*d]

end
end

Expand All @@ -71,8 +77,7 @@ end
Proj1 = integ1.cache.Proj

@testset "Order $o" for o in (1, 2, 3, 4, 5)
integ2 =
init(prob, EK0(order=o, initialization=ClassicSolverInit(init_on_ddu=true)))
integ2 = init(prob, EK1(order=o, initialization=ClassicSolverInit(init_on_ddu=true)))
rk_init = integ2.cache.x.μ
Proj2 = integ2.cache.Proj

Expand Down

0 comments on commit 8dffaac

Please sign in to comment.