diff --git a/Project.toml b/Project.toml index c4a9d0579..69bd3e23b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ProbNumDiffEq" uuid = "bf3e78b0-7d74-48a5-b855-9609533b56a5" authors = ["Nathanael Bosch"] -version = "0.3.1" +version = "0.3.2" [deps] DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" diff --git a/src/ProbNumDiffEq.jl b/src/ProbNumDiffEq.jl index 6095009e3..72ac45d27 100644 --- a/src/ProbNumDiffEq.jl +++ b/src/ProbNumDiffEq.jl @@ -93,7 +93,7 @@ include("initialization/common.jl") export TaylorModeInit, RungeKuttaInit include("algorithms.jl") -export EK0, EK1 +export EK0, EK1, EK1FDB include("alg_utils.jl") include("caches.jl") diff --git a/src/algorithms.jl b/src/algorithms.jl index 15234b187..05f4777be 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -50,3 +50,12 @@ Base.@kwdef struct EK1{IT} <: AbstractEK smooth::Bool = true initialization::IT = TaylorModeInit() end + +Base.@kwdef struct EK1FDB{IT} <: AbstractEK + prior::Symbol = :ibm + order::Int = 3 + diffusionmodel::Symbol = :dynamic + smooth::Bool = true + initialization::IT = TaylorModeInit() + jac_quality::Int = 1 +end diff --git a/src/caches.jl b/src/caches.jl index 9d687ddd6..ea87c59f7 100644 --- a/src/caches.jl +++ b/src/caches.jl @@ -145,6 +145,14 @@ function OrdinaryDiffEq.alg_cache( C2 = SRMatrix(zeros(uElType, D, 3D), zeros(uElType, D, D)) covmatcache = similar(G) + if alg isa EK1FDB + H = [E1; E2] + v = [v; v] + S = SRMatrix(zeros(uElType, 2d, D), zeros(uElType, 2d, 2d)) + measurement = Gaussian(v, S) + K = zeros(uElType, D, 2d) + end + diffmodel = alg.diffusionmodel == :dynamic ? DynamicDiffusion() : alg.diffusionmodel == :fixed ? FixedDiffusion() : diff --git a/src/initialization/taylormode.jl b/src/initialization/taylormode.jl index 2c4e279a6..115122056 100644 --- a/src/initialization/taylormode.jl +++ b/src/initialization/taylormode.jl @@ -3,18 +3,31 @@ function initial_update!(integ, cache, init::TaylorModeInit) @unpack u, f, p, t = integ @unpack d, x, Proj = cache q = integ.alg.order + D = d * (q + 1) @unpack x_tmp, x_tmp2, m_tmp, K1, K2 = cache f_derivatives = taylormode_get_derivatives(u, f, p, t, q) @assert length(0:q) == length(f_derivatives) + m_cache = Gaussian( + zeros(eltype(u), d), + SRMatrix(zeros(eltype(u), d, D), zeros(eltype(u), d, d)), + ) for (o, df) in zip(0:q, f_derivatives) if f isa DynamicalODEFunction @assert df isa ArrayPartition df = df[2, :] end pmat = f.mass_matrix * Proj(o) - condition_on!(x, pmat, view(df, :), m_tmp, K1, x_tmp.Σ, x_tmp2.Σ.mat) + condition_on!( + x, + pmat, + view(df, :), + m_cache, + view(K1, :, 1:d), + x_tmp.Σ, + x_tmp2.Σ.mat, + ) end end diff --git a/src/perform_step.jl b/src/perform_step.jl index 6662a6285..b5042abce 100644 --- a/src/perform_step.jl +++ b/src/perform_step.jl @@ -141,12 +141,17 @@ function OrdinaryDiffEq.perform_step!( end end -function evaluate_ode!(integ, x_pred, t, second_order::Val{false}) +function evaluate_ode!( + integ::OrdinaryDiffEq.ODEIntegrator{<:AbstractEK}, + x_pred, + t, + second_order::Val{false}, +) @unpack f, p, dt, alg = integ @unpack u_pred, du, ddu, measurement, R, H = integ.cache @assert iszero(R) - @unpack E0, E1 = integ.cache + @unpack E0, E1, E2 = integ.cache z, S = measurement.μ, measurement.Σ @@ -189,7 +194,79 @@ function evaluate_ode!(integ, x_pred, t, second_order::Val{false}) return nothing end -function evaluate_ode!(integ, x_pred, t, second_order::Val{true}) +function evaluate_ode!( + integ::OrdinaryDiffEq.ODEIntegrator{<:EK1FDB}, + x_pred, + t, + second_order::Val{false}, +) + @unpack f, p, dt, alg = integ + @unpack d, u_pred, du, ddu, measurement, R, H = integ.cache + @assert iszero(R) + + @unpack E0, E1, E2 = integ.cache + + z, S = measurement.μ, measurement.Σ + + (f.mass_matrix != I) && error("EK1FDB does not support mass-matrices right now") + + # Mean + _eval_f!(du, u_pred, p, t, f) + integ.destats.nf += 1 + # z .= MM*E1*x_pred.μ .- du + H1, H2 = view(H, 1:d, :), view(H, d+1:2d, :) + z1, z2 = view(z, 1:d), view(z, d+1:2d) + + H1 .= E1 + _matmul!(z1, H1, x_pred.μ) + z1 .-= du[:] + + # Cov + u_lin = u_pred + if !isnothing(f.jac) + _eval_f_jac!(ddu, u_lin, p, t, f) + elseif isinplace(f) + ForwardDiff.jacobian!(ddu, (du, u) -> f(du, u, p, t), du, u_lin) + else + ddu .= ForwardDiff.jacobian(u -> f(u, p, t), u_lin) + end + integ.destats.njacs += 1 + _matmul!(H1, ddu, E0, -1.0, 1.0) + + z2 .= (E2 * x_pred.μ .- ddu * du) + if integ.alg.jac_quality == 1 + # EK0-type approach + H2 .= E2 + elseif integ.alg.jac_quality == 2 + H2 .= E2 - ddu * ddu * E0 + elseif integ.alg.jac_quality == 3 + _z2(m) = begin + u_pred = E0 * m + du = zeros(eltype(m), d) + ddu = zeros(eltype(m), d, d) + _eval_f!(du, u_pred, p, t, f) + if !isnothing(f.jac) + _eval_f_jac!(ddu, u_pred, p, t, f) + elseif isinplace(f) + ForwardDiff.jacobian!(ddu, (du, u) -> f(du, u, p, t), du, u_pred) + else + ddu .= ForwardDiff.jacobian(u -> f(u, p, t), u_pred) + end + return (E2 * m .- ddu * du) + end + H2 .= ForwardDiff.jacobian(_z2, x_pred.μ) + else + error("EK1FDB's `jac_quality` has to be in [1,2,3]") + end + return nothing +end + +function evaluate_ode!( + integ::OrdinaryDiffEq.ODEIntegrator{<:AbstractEK}, + x_pred, + t, + second_order::Val{true}, +) @unpack f, p, dt, alg = integ @unpack d, u_pred, du, ddu, measurement, R, H = integ.cache @assert iszero(R) @@ -297,7 +374,7 @@ function smooth_all!(integ) end function estimate_errors(cache::GaussianODEFilterCache) - @unpack local_diffusion, Qh, H = cache + @unpack local_diffusion, Qh, H, d = cache if local_diffusion isa Real && isinf(local_diffusion) return Inf @@ -308,7 +385,7 @@ function estimate_errors(cache::GaussianODEFilterCache) if local_diffusion isa Diagonal _matmul!(L, H, sqrt.(local_diffusion) * Qh.squareroot) error_estimate = sqrt.(diag(L * L')) - return error_estimate + return view(error_estimate, 1:d) elseif local_diffusion isa Number _matmul!(L, H, Qh.squareroot) @@ -320,6 +397,6 @@ function estimate_errors(cache::GaussianODEFilterCache) # error_estimate .+= cache.measurement.μ .^ 2 error_estimate .= sqrt.(error_estimate) - return error_estimate + return view(error_estimate, 1:d) end end diff --git a/test/correctness.jl b/test/correctness.jl index 76b18f1c8..5a37f306d 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -21,12 +21,15 @@ for (prob, probname) in [ true_sol = solve(remake(prob, u0=big.(prob.u0)), Tsit5(), abstol=1e-20, reltol=1e-20) - for Alg in (EK0, EK1), + EK1FDB1(; kwargs...) = EK1FDB(; jac_quality=1, kwargs...) + EK1FDB2(; kwargs...) = EK1FDB(; jac_quality=2, kwargs...) + EK1FDB3(; kwargs...) = EK1FDB(; jac_quality=3, kwargs...) + for Alg in (EK0, EK1, EK1FDB1, EK1FDB2, EK1FDB3), diffusion in [:fixed, :dynamic, :fixedMV, :dynamicMV], init in [TaylorModeInit(), RungeKuttaInit()], q in [2, 3, 5] - if Alg == EK1 && diffusion in (:fixedMV, :dynamicMV) + if diffusion in (:fixedMV, :dynamicMV) && Alg != EK0 continue end @@ -55,12 +58,15 @@ for (prob, probname) in [ solve(remake(prob, u0=big.(prob.u0)), Tsit5(), abstol=1e-20, reltol=1e-20) true_dense_vals = true_sol.(t_eval) - for Alg in (EK0, EK1), + EK1FDB1(; kwargs...) = EK1FDB(; jac_quality=1, kwargs...) + EK1FDB2(; kwargs...) = EK1FDB(; jac_quality=2, kwargs...) + EK1FDB3(; kwargs...) = EK1FDB(; jac_quality=3, kwargs...) + for Alg in (EK0, EK1, EK1FDB1, EK1FDB2, EK1FDB3), diffusion in [:fixed, :dynamic, :fixedMV, :dynamicMV], init in [TaylorModeInit(), RungeKuttaInit()], q in [2, 3, 5] - if Alg == EK1 && diffusion in (:fixedMV, :dynamicMV) + if diffusion in (:fixedMV, :dynamicMV) && Alg != EK0 continue end diff --git a/test/specific_problems.jl b/test/specific_problems.jl index 8cf3eadc3..7cd6e9331 100644 --- a/test/specific_problems.jl +++ b/test/specific_problems.jl @@ -66,11 +66,47 @@ end @test sol isa ProbNumDiffEq.ProbODESolution end -@testset "OOP problem definition" begin - prob = ODEProblem((u, p, t) -> ([p[1] * u[1] .* (1 .- u[1])]), [1e-1], (0.0, 5), [3.0]) - @test solve(prob, EK0(order=4)) isa ProbNumDiffEq.ProbODESolution - prob = ProbNumDiffEq.remake_prob_with_jac(prob) - @test solve(prob, EK1(order=4)) isa ProbNumDiffEq.ProbODESolution +@testset "IIP problem" begin + f(du, u, p, t) = (du[1] = p[1] * u[1] .* (1 .- u[1])) + prob = ODEProblem(f, [1e-1], (0.0, 5), [3.0]) + @testset "without jacobian" begin + @test solve(prob, EK0(order=4)) isa ProbNumDiffEq.ProbODESolution + # first without defined jac + @test solve(prob, EK1(order=4)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=1)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=2)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=3)) isa ProbNumDiffEq.ProbODESolution + end + @testset "with jacobian" begin + # now with defined jac + prob = ProbNumDiffEq.remake_prob_with_jac(prob) + @test solve(prob, EK1(order=4)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=1)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=2)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=3)) isa ProbNumDiffEq.ProbODESolution + end +end + +@testset "OOP problem" begin + f(u, p, t) = ([p[1] * u[1] .* (1 .- u[1])]) + prob = ODEProblem(f, [1e-1], (0.0, 5), [3.0]) + @testset "without jacobian" begin + # first without defined jac + @test solve(prob, EK0(order=4)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1(order=4)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=1)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=2)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=3)) isa ProbNumDiffEq.ProbODESolution + end + @testset "with jacobian" begin + # now with defined jac + prob = ProbNumDiffEq.remake_prob_with_jac(prob) + @test solve(prob, EK0(order=4)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1(order=4)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=1)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=2)) isa ProbNumDiffEq.ProbODESolution + @test solve(prob, EK1FDB(order=4, jac_quality=3)) isa ProbNumDiffEq.ProbODESolution + end end @testset "Callback: Harmonic Oscillator with condition on E=2" begin