Skip to content

Commit

Permalink
Fix ClassicSolverInit and start restoring the previous MM behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Oct 11, 2023
1 parent d3c7d45 commit 7c08312
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 41 deletions.
31 changes: 26 additions & 5 deletions src/initialization/classicsolverinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,29 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
end

@unpack ddu, du, x_tmp, m_tmp, K1 = cache
@unpack x_tmp, K1, C_Dxd, C_DxD, C_dxd, measurement = cache

# Initialize on u0; taking special care for DynamicalODEProblems
is_secondorder = integ.f isa DynamicalODEFunction
_u = is_secondorder ? view(u.x[2], :) : view(u, :)
Mcache = cache.C_DxD
condition_on!(x, Proj(0), _u, cache)
# condition_on!(x, Proj(0), _u, cache)
begin
H = Proj(0)
measurement.μ .= H*x.μ - _u
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
end
is_secondorder ? f.f1(du, u.x[1], u.x[2], p, t) : f(du, u, p, t)
integ.stats.nf += 1
condition_on!(x, Proj(1), view(du, :), cache)
# condition_on!(x, Proj(1), view(du, :), cache)
begin
H = Proj(1)
measurement.μ .= H*x.μ - view(du, :)
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
end

if q < 2
return
Expand All @@ -39,7 +53,14 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
ForwardDiff.jacobian!(ddu, (du, u) -> _f(du, u, p, t), du, u)
end
ddfddu = ddu * view(du, :) + view(dfdt, :)
condition_on!(x, Proj(2), ddfddu, cache)
# condition_on!(x, Proj(2), ddfddu, cache)
begin
H = Proj(2)
measurement.μ .= H*x.μ - ddfddu
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
end
if q < 3
return
end
Expand Down Expand Up @@ -112,7 +133,7 @@ function rk_init_improve(cache::AbstractODEFilterCache, ts, us, dt)

H = cache.E0 * PI
measurement.μ .= H * x_pred.μ .- u
X_A_Xt!(measurement.Σ, x_pred.Σ, H)
fast_X_A_Xt!(measurement.Σ, x_pred.Σ, H)

update!(x_filt, x_pred, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
push!(filts, copy(x_filt))
Expand Down
19 changes: 6 additions & 13 deletions src/initialization/taylormode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function initial_update!(integ, cache, init::TaylorModeInit)
@unpack d, q, q, x, Proj = cache
D = d * (q + 1)

@unpack x_tmp, K1 = cache
@unpack x_tmp, K1, C_Dxd, C_DxD, C_dxd, measurement = cache
if size(K1, 2) != d
K1 = K1[:, 1:d]
end
Expand All @@ -16,28 +16,21 @@ function initial_update!(integ, cache, init::TaylorModeInit)
f_derivatives = taylormode_get_derivatives(u, f, p, t, q)
integ.stats.nf += q
@assert length(0:q) == length(f_derivatives)
m_cache = Gaussian(zeros(eltype(u), d), PSDMatrix(zeros(eltype(u), D, d)))
for (o, df) in zip(0:q, f_derivatives)
if f isa DynamicalODEFunction
@assert df isa ArrayPartition
df = df[2, :]
end
# pmat = f.mass_matrix * Proj(o)
@assert f.mass_matrix === I
pmat = Proj(o)

if !(df isa AbstractVector)
df = df[:]
end

# condition_on!(x, pmat, df, cache)
x.μ[(o+1):(q+1):end] .= df
end
if x.Σ.R isa Kronecker.KroneckerProduct
x.Σ.R.A .= 0
x.Σ.R.B .= 0
else
x.Σ.R .= 0
H = f.mass_matrix * Proj(o)
measurement.μ .= H*x.μ - df
fast_X_A_Xt!(measurement.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, measurement, H, K1, C_Dxd, C_DxD, C_dxd)
end
end

Expand Down
62 changes: 39 additions & 23 deletions test/state_init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,34 @@ using Test

import ODEProblemLibrary: prob_ode_fitzhughnagumo, prob_ode_pleiades

d = 2
q = 6
D = d * (q + 1)
@testset "Testproblem" begin
d = 2
q = 6
D = d * (q + 1)

a, b = 1.1, -0.5
f(u, p, t) = [a * u[1], b * u[2]]
u0 = [0.1, 1.0]
tspan = (0.0, 5.0)
t0, T = tspan
prob = ODEProblem(f, u0, tspan)
p = prob.p
a, b = 1.1, -0.5
f(u, p, t) = [a * u[1], b * u[2]]
u0 = [0.1, 1.0]
tspan = (0.0, 5.0)
t0, T = tspan
prob = ODEProblem(f, u0, tspan)
p = prob.p

# True Solutions and derivatives
u(t) = [a^0 * u0[1] * exp(a * t), u0[2] * exp(b * t)]
du(t) = [a^1 * u0[1] * exp(a * t), b * u0[2] * exp(b * t)]
ddu(t) = [a^2 * u0[1] * exp(a * t), (b)^2 * u0[2] * exp(b * t)]
dddu(t) = [a^3 * u0[1] * exp(a * t), (b)^3 * u0[2] * exp(b * t)]
ddddu(t) = [a^4 * u0[1] * exp(a * t), (b)^4 * u0[2] * exp(b * t)]
dddddu(t) = [a^5 * u0[1] * exp(a * t), (b)^5 * u0[2] * exp(b * t)]
ddddddu(t) = [a^6 * u0[1] * exp(a * t), (b)^6 * u0[2] * exp(b * t)]
true_init_states = [u(t0); du(t0); ddu(t0); dddu(t0); ddddu(t0); dddddu(t0); ddddddu(t0)]
# True Solutions and derivatives
u(t) = [a^0 * u0[1] * exp(a * t), u0[2] * exp(b * t)]
du(t) = [a^1 * u0[1] * exp(a * t), b * u0[2] * exp(b * t)]
ddu(t) = [a^2 * u0[1] * exp(a * t), (b)^2 * u0[2] * exp(b * t)]
dddu(t) = [a^3 * u0[1] * exp(a * t), (b)^3 * u0[2] * exp(b * t)]
ddddu(t) = [a^4 * u0[1] * exp(a * t), (b)^4 * u0[2] * exp(b * t)]
dddddu(t) = [a^5 * u0[1] * exp(a * t), (b)^5 * u0[2] * exp(b * t)]
ddddddu(t) = [a^6 * u0[1] * exp(a * t), (b)^6 * u0[2] * exp(b * t)]
true_init_states =
[u(t0); du(t0); ddu(t0); dddu(t0); ddddu(t0); dddddu(t0); ddddddu(t0)]

@testset "Taylormode initialization" begin
@testset "IIP" begin
f!(du, u, p, t) = (du .= f(u, p, t))
prob = ODEProblem{true,true}(f!, u0, tspan)
f!(du, u, p, t) = (du .= f(u, p, t))
prob = ODEProblem{true,true}(f!, u0, tspan)

@testset "`taylormode_get_derivatives`" begin
dfs = ProbNumDiffEq.taylormode_get_derivatives(
prob.u0,
prob.f,
Expand All @@ -43,6 +44,21 @@ true_init_states = [u(t0); du(t0); ddu(t0); dddu(t0); ddddu(t0); dddddu(t0); ddd
@test length(dfs) == q + 1
@test true_init_states vcat(dfs...)
end

@testset "Taylormode: `initial_update!`" begin
integ = init(prob, EK0(order=q))
ProbNumDiffEq.initial_update!(integ, integ.cache, TaylorModeInit())
x = integ.cache.x
@test reshape(x.μ, :, 2)'[:] true_init_states
end

@testset "Low-order exact init via ClassiSolverInit: `initial_update!`" begin
_q = 2
integ = init(prob, EK0(order=_q, initialization=ClassicSolverInit(init_on_ddu=true)))
ProbNumDiffEq.initial_update!(integ, integ.cache, integ.alg.initialization)
x = integ.cache.x
@test reshape(x.μ, :, 2)'[:] true_init_states[1:(_q+1)*d]
end
end

@testset "Compare TaylorModeInit and ClassicSolverInit" begin
Expand Down

0 comments on commit 7c08312

Please sign in to comment.