Skip to content

Commit

Permalink
MPRK22 supports StaticArrays & examples/test
Browse files Browse the repository at this point in the history
  • Loading branch information
SKopecz committed Mar 22, 2024
1 parent f326abb commit 325cee0
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 37 deletions.
102 changes: 86 additions & 16 deletions examples/04_example_problemlibrary.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Install packages
import Pkg
Pkg.activate(@__DIR__)
#Pkg.develop(path = dirname(@__DIR__))
#Pkg.instantiate()
Pkg.develop(path = dirname(@__DIR__))
Pkg.instantiate()

# load packages
using PositiveIntegrators
Expand All @@ -23,11 +23,18 @@ f6(t, u1, u2, u3, u4, u5, u6) = (t, u1 + u2 + u3 + u4 + u5 + u6)
f_brusselator(t, u1, u2, u3, u4, u5, u6) = (t, 0.55 * (u1 + u2 + u3 + u4 + u5 + u6))

## linear model ##########################################################
sol_linmod = solve(prob_pds_linmod, Tsit5());
sol_linmod_MPE = solve(prob_pds_linmod, MPE(), dt = 0.2);
sol_linmod_MPRK = solve(prob_pds_linmod, MPRK22(0.5), dt = 0.2);

# plot
myplot(sol_linmod_MPE, "MPE")
p1 = plot(sol_linmod)
myplot!(sol_linmod_MPE, "MPE")
plot!(sol_linmod_MPE, idxs = (f2, 0, 1, 2))
p2 = plot(sol_linmod)
myplot!(sol_linmod_MPRK, "MPRK")
plot!(sol_linmod_MPRK, idxs = (f2, 0, 1, 2))
plot(p1, p2)

# convergence order
# error based on analytic solution
Expand All @@ -42,76 +49,128 @@ sims = convergence_tab_plot(prob_pds_linmod, [MPE(), Euler()], test_setup;
@assert sims[1].𝒪est[:l∞] > 0.9
#savefig("figs/error_linmod_reference.svg")

sims = convergence_tab_plot(prob_pds_linmod,
[MPRK22(1.0), MPRK22(2 / 3), MPRK22(0.5), Heun()], test_setup;
dts = 0.5 .^ (5:18), order_plot = true);
for i in 1:4
@assert sims[i].𝒪est[:l∞] > 1.9
end
## nonlinear model ########################################################
sol_nonlinmod = solve(prob_pds_nonlinmod, Tsit5());
sol_nonlinmod_MPE = solve(prob_pds_nonlinmod, MPE(), dt = 0.5);
sol_nonlinmod_MPRK = solve(prob_pds_nonlinmod, MPRK22(1.0), dt = 0.5, adaptive = false);

# plot
plot(sol_nonlinmod, legend = :right)
p1 = plot(sol_nonlinmod, legend = :right)
myplot!(sol_nonlinmod_MPE, "MPE")
plot!(sol_nonlinmod_MPE, idxs = (f3, 0, 1, 2, 3))
p2 = plot(sol_nonlinmod, legend = :right)
myplot!(sol_nonlinmod_MPRK, "MPRK")
plot!(sol_nonlinmod_MPRK, idxs = (f3, 0, 1, 2, 3))
plot(p1, p2, layout = (2, 1))

# convergence order
test_setup = Dict(:alg => Vern9(), :reltol => 1e-14, :abstol => 1e-14)
sims = convergence_tab_plot(prob_pds_nonlinmod, [MPE(), Euler()], test_setup;
dts = 0.5 .^ (3:17), order_plot = true);
dts = 0.5 .^ (3:12), order_plot = true);
@assert sims[1].𝒪est[:l∞] > 0.9

sims = convergence_tab_plot(prob_pds_nonlinmod,
[MPRK22(1.0), MPRK22(2 / 3), MPRK22(0.5), Heun()], test_setup;
dts = 0.5 .^ (3:12), order_plot = true);
for i in 1:4
@assert sims[i].𝒪est[:l∞] > 1.9
end
## robertson problem ######################################################
sol_robertson = solve(prob_pds_robertson, Rosenbrock23());
# Cannot use MPE() since adaptive time stepping is not implemented
sol_robertson_MPRK = solve(prob_pds_robertson, MPRK22(1.0));

# plot
plot(sol_robertson[2:end],
idxs = [(0, 1), ((x, y) -> (x, 1e4 .* y), 0, 2), (0, 3)],
color = palette(:default)[1:3]', legend = :right, xaxis = :log)
plot!(sol_robertson[2:end], idxs = (f3, 0, 1, 2, 3), xaxis = :log)
plot!(sol_robertson_MPRK[2:end],
idxs = [(0, 1), ((x, y) -> (x, 1e4 .* y), 0, 2), (0, 3)],
markershape = :circle,
color = palette(:default)[1:3]', legend = :right, xaxis = :log)
plot!(sol_robertson_MPRK[2:end], idxs = (f3, 0, 1, 2, 3), markershape = :circle,
xaxis = :log)

## brusselator problem ####################################################
sol_brusselator = solve(prob_pds_brusselator, Tsit5());
sol_brusselator_MPE = solve(prob_pds_brusselator, MPE(), dt = 0.25);
sol_brusselator_MPRK = solve(prob_pds_brusselator, MPRK22(1.0), dt = 0.25, adaptive = false);

# plot
plot(sol_brusselator, legend = :outerright)
p1 = plot(sol_brusselator, legend = :outerright)
myplot!(sol_brusselator_MPE, "MPE")
plot!(sol_brusselator_MPE, idxs = (f_brusselator, 0, 1, 2, 3, 4, 5, 6),
label = "f_brusselator")
p2 = plot(sol_brusselator, legend = :outerright)
myplot!(sol_brusselator_MPRK, "MPRK")
plot!(sol_brusselator_MPRK, idxs = (f_brusselator, 0, 1, 2, 3, 4, 5, 6),
label = "f_brusselator")
plot(p1, p2, layout = (2, 1))

# convergence order
test_setup = Dict(:alg => Vern9(), :reltol => 1e-14, :abstol => 1e-14)
sims = convergence_tab_plot(prob_pds_brusselator, [MPE()], test_setup; dts = 0.5 .^ (3:17),
sims = convergence_tab_plot(prob_pds_brusselator, [MPE(), Euler()], test_setup;
dts = 0.5 .^ (3:15),
order_plot = true);
@assert sims[1].𝒪est[:l∞] > 0.9

sims = convergence_tab_plot(prob_pds_brusselator,
[MPRK22(1.0), MPRK22(2 / 3), MPRK22(0.5), Heun()],
test_setup; dts = 0.5 .^ (5:15),
order_plot = true);
for i in 1:4
@assert sims[i].𝒪est[:l∞] > 1.9
end
## SIR model ##############################################################
sol_sir = solve(prob_pds_sir, Tsit5());
sol_sir_Euler = solve(prob_pds_sir, Euler(), dt = 0.5);
sol_sir_MPE = solve(prob_pds_sir, MPE(), dt = 0.5);
sol_sir_MPRK = solve(prob_pds_sir, MPRK22(1.0), dt = 0.5, adaptive = false);

# plot
p1 = plot(sol_sir)
myplot!(sol_sir_MPE, "MPE")
plot!(sol_sir_MPE, idxs = (f3, 0, 1, 2, 3), label = "f3")
p2 = plot(sol_sir)
myplot!(sol_sir_Euler, "Euler")
plot!(sol_sir_Euler, idxs = (f3, 0, 1, 2, 3), label = "f3")
plot(p1, p2)
p2 = plot(sol_sir)
myplot!(sol_sir_MPE, "MPE")
plot!(sol_sir_MPE, idxs = (f3, 0, 1, 2, 3), label = "f3")
p3 = plot(sol_sir)
myplot!(sol_sir_MPRK, "MPRK")
plot!(sol_sir_MPRK, idxs = (f3, 0, 1, 2, 3), label = "f3")
plot(p1, p2, p3, layout = (2, 2))

# convergence order
test_setup = Dict(:alg => Vern9(), :reltol => 1e-14, :abstol => 1e-14)
sims = convergence_tab_plot(prob_pds_sir, [MPE(), Euler()], test_setup; dts = 0.5 .^ (1:15),
order_plot = true);
@assert sims[1].𝒪est[:l∞] > 0.9

sims = convergence_tab_plot(prob_pds_sir, [MPRK22(1.0), MPRK22(2 / 3), MPRK22(0.5), Heun()],
test_setup; dts = 0.5 .^ (5:15),
order_plot = true);
for i in 1:4
@assert sims[i].𝒪est[:l∞] > 1.9
end
## bertolazzi problem #####################################################
sol_bertolazzi = solve(prob_pds_bertolazzi, TRBDF2());
sol_bertolazzi_MPE = solve(prob_pds_bertolazzi, MPE(), dt = 0.01);
sol_bertolazzi_MPRK = solve(prob_pds_bertolazzi, MPRK22(1.0), dt = 0.01);

# plot
plot(sol_bertolazzi, legend = :right)
p1 = plot(sol_bertolazzi, legend = :right)
myplot!(sol_bertolazzi_MPE, "MPE")
ylims!((-0.5, 3.5))
plot!(sol_bertolazzi_MPE, idxs = (f3, 0, 1, 2, 3))
p2 = plot(sol_bertolazzi, legend = :right)
myplot!(sol_bertolazzi_MPRK, "MPRK")
ylims!((-0.5, 3.5))
plot!(sol_bertolazzi_MPRK, idxs = (f3, 0, 1, 2, 3))
plot(p1, p2, layout = (2, 1))

# convergence order
test_setup = Dict(:alg => Rosenbrock23(), :reltol => 1e-8, :abstol => 1e-8)
Expand All @@ -121,20 +180,31 @@ convergence_tab_plot(prob_pds_bertolazzi, [MPE(), ImplicitEuler()], test_setup;
### npzd problem ##########################################################
sol_npzd = solve(prob_pds_npzd, Rosenbrock23());
sol_npzd_MPE = solve(prob_pds_npzd, MPE(), dt = 0.1);
sol_npzd_MPRK = solve(prob_pds_npzd, MPRK22(1.0), dt = 0.1, adaptive = false);

# plot
plot(sol_npzd)
p1 = plot(sol_npzd)
myplot!(sol_npzd_MPE, "MPE")
plot!(sol_npzd_MPE, idxs = (f_npzd, 0, 1, 2, 3, 4), label = "f_npzd")
plot!(legend = :bottomright)
p2 = plot(sol_npzd)
myplot!(sol_npzd_MPRK, "MPRK")
plot!(sol_npzd_MPRK, idxs = (f_npzd, 0, 1, 2, 3, 4), label = "f_npzd")
plot!(legend = :bottomright)
plot(p1, p2, layout = (2, 1))

# convergence order
# error should take all time steps into account, not only the final time!
test_setup = Dict(:alg => Rosenbrock23(), :reltol => 1e-14, :abstol => 1e-14)
sims = convergence_tab_plot(prob_pds_npzd, [MPE(), ImplicitEuler()], test_setup;
dts = 0.5 .^ (5:17), order_plot = true);
@assert sims[1].𝒪est[:l∞] > 0.9

sims = convergence_tab_plot(prob_pds_npzd,
[MPRK22(1.0), MPRK22(2 / 3), MPRK22(0.5), Heun()], test_setup;
dts = 0.5 .^ (10:17), order_plot = true);
for i in 1:4
@assert sims[i].𝒪est[:l∞] > 1.9
end
### stratospheric reaction problem ####################################################
sol_stratreac = solve(prob_pds_stratreac, TRBDF2(autodiff = false));
# currently no solver for non-conservative PDS implemented
Expand Down
6 changes: 5 additions & 1 deletion examples/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ function convergence_tab_plot(prob, algs, test_setup = nothing; dts = 0.5 .^ (1:
p = -log2.(err[2:end] ./ err[1:(end - 1)])
#table
algname = string(Base.typename(typeof(algs[i])).wrapper)
if algname == "MPRK22"
my_matches = match(r"(?<= alpha = )([0-9][.][0-9]*)", string(algs[i]))
algname = algname .* "_" .* my_matches.match
end
#algname = string(algs[i])
pretty_table([dts err [NaN; p]]; header = (["dt", "err", "p"]),
title = string("\n\n", string(algs[i])), title_alignment = :c,
Expand Down Expand Up @@ -56,7 +60,7 @@ function _myplot(plotf, sol, name = "", analytic = false)
plotf(sol, color = palette(:default)[1:N]', legend = :right, plot_analytic = false)
end
p = plot!(sol, color = palette(:default)[1:N]', denseplot = false,
markershape = :circle, markerstrokecolor = palette(:default)[1:N]',
markershape = :circle, #markerstrokecolor = palette(:default)[1:N]',
linecolor = invisible(), label = "")
title!(name)
return p
Expand Down
62 changes: 42 additions & 20 deletions src/mprk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,14 @@ This modified Patankar-Runge-Kutta method requires the special structure of a
Applied Numerical Mathematics 182 (2022): 117-147.
[DOI: 10.1016/j.apnum.2022.07.014](https://doi.org/10.1016/j.apnum.2022.07.014)
"""
struct MPRK22{T, Thread} <: OrdinaryDiffEqAdaptiveAlgorithm
struct MPRK22{T, Thread, F} <: OrdinaryDiffEqAdaptiveAlgorithm
alpha::T
thread::Thread
linsolve::F
end

function MPRK22(alpha; thread = False())
MPRK22{typeof(alpha), typeof(thread)}(alpha, thread)
function MPRK22(alpha; thread = False(), linsolve = nothing)
MPRK22{typeof(alpha), typeof(thread), typeof(linsolve)}(alpha, thread, linsolve)
end

OrdinaryDiffEq.alg_order(alg::MPRK22) = 2
Expand Down Expand Up @@ -310,37 +311,58 @@ function initialize!(integrator, cache::MPRK22ConstantCache)
end

function perform_step!(integrator, cache::MPRK22ConstantCache, repeat_step = false)
@unpack t, dt, uprev, f, p = integrator
@unpack a21, b1, b2, c2 = cache
@unpack alg, t, dt, uprev, f, p = integrator
@unpack a21, b1, b2, small_constant = cache

safeguard = floatmin(eltype(uprev))
# Attention: Implementation assumes that the pds is conservative,
# i.e. f.p[i,i] == 0 for all i

uprev .= uprev .+ safeguard
P = f.p(uprev, p, t) # evaluate production matrix
Ptmp = a21 * P

P = f.p(uprev, p, t) # evaluate production terms
D = vec(sum(P, dims = 1)) # sum destruction terms
# avoid division by zero due to zero patankar weights
σ = add_small_constant(uprev, small_constant)

M = -dt * a21 * P ./ reshape(uprev, 1, :) # divide production terms by Patankar-weights
M[diagind(M)] .+= 1.0 .+ dt * a21 * D ./ uprev
u = M \ uprev
# build linear system matrix
M = build_mprk_matrix(Ptmp, σ, dt)

u .= u .+ safeguard
# solve linear system
linprob = LinearProblem(M, uprev)
sol = solve(linprob, alg.linsolve,
alias_A = false, alias_b = false,
assumptions = LinearSolve.OperatorAssumptions(true))
u = sol.u

σ = uprev .* (u ./ uprev) .^ (1 / a21) .+ safeguard
# compute Patankar weight denominator
if a21 == 1.0
σ = u
else
# σ = σ .* (u ./ σ) .^ (1 / a21) # generated Infs when solving brusselator
σ = σ .^ (1 - 1 / a21) .* u .^ (1 / a21)
end
# avoid division by zero due to zero patankar weights
σ = add_small_constant(σ, small_constant)

P2 = f.p(u, p, t + a21 * dt)
D2 = vec(sum(P2, dims = 1))
M .= -dt * (b1 * P + b2 * P2) ./ reshape(σ, 1, :)
M[diagind(M)] .+= 1.0 .+ dt * (b1 * D + b2 * D2) ./ σ
u = M \ uprev
Ptmp = b1 * P + b2 * P2

# build linear system matrix
M = build_mprk_matrix(Ptmp, σ, dt)

u .= u .+ safeguard
# solve linear system
linprob = LinearProblem(M, uprev)
sol = solve(linprob, alg.linsolve,
alias_A = false, alias_b = false,
assumptions = LinearSolve.OperatorAssumptions(true))
u = sol.u

k = f(u, p, t + dt) # For the interpolation, needs k at the updated point
integrator.stats.nf += 1
integrator.fsallast = k

#copied from perform_step for HeunConstantCache
# copied from perform_step for HeunConstantCache
# If a21 = 1.0, then σ is the MPE approximation and thus suited for stiff problems.
# If a21 ≠ 1.0, σ might be a bad choice to estimate errors.
tmp = u - σ
atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
Expand Down

0 comments on commit 325cee0

Please sign in to comment.