From 7c08312c12e910dacf04c7b17ce082082a5c4111 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Wed, 11 Oct 2023 09:51:42 +0200 Subject: [PATCH] Fix ClassicSolverInit and start restoring the previous MM behavior --- src/initialization/classicsolverinit.jl | 31 +++++++++++-- src/initialization/taylormode.jl | 19 +++----- test/state_init.jl | 62 ++++++++++++++++--------- 3 files changed, 71 insertions(+), 41 deletions(-) diff --git a/src/initialization/classicsolverinit.jl b/src/initialization/classicsolverinit.jl index 653a0c496..ca359f278 100644 --- a/src/initialization/classicsolverinit.jl +++ b/src/initialization/classicsolverinit.jl @@ -8,15 +8,29 @@ function initial_update!(integ, cache, ::ClassicSolverInit) end @unpack ddu, du, x_tmp, m_tmp, K1 = cache + @unpack x_tmp, K1, C_Dxd, C_DxD, C_dxd, measurement = cache # Initialize on u0; taking special care for DynamicalODEProblems is_secondorder = integ.f isa DynamicalODEFunction _u = is_secondorder ? view(u.x[2], :) : view(u, :) - Mcache = cache.C_DxD - condition_on!(x, Proj(0), _u, cache) + # condition_on!(x, Proj(0), _u, cache) + begin + H = Proj(0) + measurement.μ .= H*x.μ - _u + fast_X_A_Xt!(measurement.Σ, x.Σ, H) + copy!(x_tmp, x) + update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd) + end is_secondorder ? f.f1(du, u.x[1], u.x[2], p, t) : f(du, u, p, t) integ.stats.nf += 1 - condition_on!(x, Proj(1), view(du, :), cache) + # condition_on!(x, Proj(1), view(du, :), cache) + begin + H = Proj(1) + measurement.μ .= H*x.μ - view(du, :) + fast_X_A_Xt!(measurement.Σ, x.Σ, H) + copy!(x_tmp, x) + update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd) + end if q < 2 return @@ -39,7 +53,14 @@ function initial_update!(integ, cache, ::ClassicSolverInit) ForwardDiff.jacobian!(ddu, (du, u) -> _f(du, u, p, t), du, u) end ddfddu = ddu * view(du, :) + view(dfdt, :) - condition_on!(x, Proj(2), ddfddu, cache) + # condition_on!(x, Proj(2), ddfddu, cache) + begin + H = Proj(2) + measurement.μ .= H*x.μ - ddfddu + fast_X_A_Xt!(measurement.Σ, x.Σ, H) + copy!(x_tmp, x) + update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd) + end if q < 3 return end @@ -112,7 +133,7 @@ function rk_init_improve(cache::AbstractODEFilterCache, ts, us, dt) H = cache.E0 * PI measurement.μ .= H * x_pred.μ .- u - X_A_Xt!(measurement.Σ, x_pred.Σ, H) + fast_X_A_Xt!(measurement.Σ, x_pred.Σ, H) update!(x_filt, x_pred, measurement, H, K1, C_Dxd, C_DxD, C_dxd) push!(filts, copy(x_filt)) diff --git a/src/initialization/taylormode.jl b/src/initialization/taylormode.jl index baf85ac26..cffa02329 100644 --- a/src/initialization/taylormode.jl +++ b/src/initialization/taylormode.jl @@ -3,7 +3,7 @@ function initial_update!(integ, cache, init::TaylorModeInit) @unpack d, q, q, x, Proj = cache D = d * (q + 1) - @unpack x_tmp, K1 = cache + @unpack x_tmp, K1, C_Dxd, C_DxD, C_dxd, measurement = cache if size(K1, 2) != d K1 = K1[:, 1:d] end @@ -16,28 +16,21 @@ function initial_update!(integ, cache, init::TaylorModeInit) f_derivatives = taylormode_get_derivatives(u, f, p, t, q) integ.stats.nf += q @assert length(0:q) == length(f_derivatives) - m_cache = Gaussian(zeros(eltype(u), d), PSDMatrix(zeros(eltype(u), D, d))) for (o, df) in zip(0:q, f_derivatives) if f isa DynamicalODEFunction @assert df isa ArrayPartition df = df[2, :] end - # pmat = f.mass_matrix * Proj(o) - @assert f.mass_matrix === I - pmat = Proj(o) if !(df isa AbstractVector) df = df[:] end - # condition_on!(x, pmat, df, cache) - x.μ[(o+1):(q+1):end] .= df - end - if x.Σ.R isa Kronecker.KroneckerProduct - x.Σ.R.A .= 0 - x.Σ.R.B .= 0 - else - x.Σ.R .= 0 + H = f.mass_matrix * Proj(o) + measurement.μ .= H*x.μ - df + fast_X_A_Xt!(measurement.Σ, x.Σ, H) + copy!(x_tmp, x) + update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd) end end diff --git a/test/state_init.jl b/test/state_init.jl index 1c6c46b72..b16f9b0a1 100644 --- a/test/state_init.jl +++ b/test/state_init.jl @@ -6,33 +6,34 @@ using Test import ODEProblemLibrary: prob_ode_fitzhughnagumo, prob_ode_pleiades -d = 2 -q = 6 -D = d * (q + 1) +@testset "Testproblem" begin + d = 2 + q = 6 + D = d * (q + 1) -a, b = 1.1, -0.5 -f(u, p, t) = [a * u[1], b * u[2]] -u0 = [0.1, 1.0] -tspan = (0.0, 5.0) -t0, T = tspan -prob = ODEProblem(f, u0, tspan) -p = prob.p + a, b = 1.1, -0.5 + f(u, p, t) = [a * u[1], b * u[2]] + u0 = [0.1, 1.0] + tspan = (0.0, 5.0) + t0, T = tspan + prob = ODEProblem(f, u0, tspan) + p = prob.p -# True Solutions and derivatives -u(t) = [a^0 * u0[1] * exp(a * t), u0[2] * exp(b * t)] -du(t) = [a^1 * u0[1] * exp(a * t), b * u0[2] * exp(b * t)] -ddu(t) = [a^2 * u0[1] * exp(a * t), (b)^2 * u0[2] * exp(b * t)] -dddu(t) = [a^3 * u0[1] * exp(a * t), (b)^3 * u0[2] * exp(b * t)] -ddddu(t) = [a^4 * u0[1] * exp(a * t), (b)^4 * u0[2] * exp(b * t)] -dddddu(t) = [a^5 * u0[1] * exp(a * t), (b)^5 * u0[2] * exp(b * t)] -ddddddu(t) = [a^6 * u0[1] * exp(a * t), (b)^6 * u0[2] * exp(b * t)] -true_init_states = [u(t0); du(t0); ddu(t0); dddu(t0); ddddu(t0); dddddu(t0); ddddddu(t0)] + # True Solutions and derivatives + u(t) = [a^0 * u0[1] * exp(a * t), u0[2] * exp(b * t)] + du(t) = [a^1 * u0[1] * exp(a * t), b * u0[2] * exp(b * t)] + ddu(t) = [a^2 * u0[1] * exp(a * t), (b)^2 * u0[2] * exp(b * t)] + dddu(t) = [a^3 * u0[1] * exp(a * t), (b)^3 * u0[2] * exp(b * t)] + ddddu(t) = [a^4 * u0[1] * exp(a * t), (b)^4 * u0[2] * exp(b * t)] + dddddu(t) = [a^5 * u0[1] * exp(a * t), (b)^5 * u0[2] * exp(b * t)] + ddddddu(t) = [a^6 * u0[1] * exp(a * t), (b)^6 * u0[2] * exp(b * t)] + true_init_states = + [u(t0); du(t0); ddu(t0); dddu(t0); ddddu(t0); dddddu(t0); ddddddu(t0)] -@testset "Taylormode initialization" begin - @testset "IIP" begin - f!(du, u, p, t) = (du .= f(u, p, t)) - prob = ODEProblem{true,true}(f!, u0, tspan) + f!(du, u, p, t) = (du .= f(u, p, t)) + prob = ODEProblem{true,true}(f!, u0, tspan) + @testset "`taylormode_get_derivatives`" begin dfs = ProbNumDiffEq.taylormode_get_derivatives( prob.u0, prob.f, @@ -43,6 +44,21 @@ true_init_states = [u(t0); du(t0); ddu(t0); dddu(t0); ddddu(t0); dddddu(t0); ddd @test length(dfs) == q + 1 @test true_init_states ≈ vcat(dfs...) end + + @testset "Taylormode: `initial_update!`" begin + integ = init(prob, EK0(order=q)) + ProbNumDiffEq.initial_update!(integ, integ.cache, TaylorModeInit()) + x = integ.cache.x + @test reshape(x.μ, :, 2)'[:] ≈ true_init_states + end + + @testset "Low-order exact init via ClassiSolverInit: `initial_update!`" begin + _q = 2 + integ = init(prob, EK0(order=_q, initialization=ClassicSolverInit(init_on_ddu=true))) + ProbNumDiffEq.initial_update!(integ, integ.cache, integ.alg.initialization) + x = integ.cache.x + @test reshape(x.μ, :, 2)'[:] ≈ true_init_states[1:(_q+1)*d] + end end @testset "Compare TaylorModeInit and ClassicSolverInit" begin