Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor time integrator 2N and 3Star more similar to OrdinaryDiffEq.jl integrators #1975

Merged
merged 11 commits into from
Jun 19, 2024
110 changes: 63 additions & 47 deletions src/time_integration/methods_2N.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,8 @@ function Base.getproperty(integrator::SimpleIntegrator2N, field::Symbol)
return getfield(integrator, field)
end

# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1
function solve(ode::ODEProblem, alg::T;
dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
function init(ode::ODEProblem, alg::T;
dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
warisa-r marked this conversation as resolved.
Show resolved Hide resolved
u = copy(ode.u0)
du = similar(u)
u_tmp = similar(u)
Expand All @@ -129,67 +128,84 @@ function solve(ode::ODEProblem, alg::T;
error("unsupported")
end

return integrator
end

# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1
function solve(ode::ODEProblem, alg::T;
dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
warisa-r marked this conversation as resolved.
Show resolved Hide resolved
integrator = init(ode, alg, dt = dt, callback = callback; kwargs...)

# Start actual solve
solve!(integrator)
end

function solve!(integrator::SimpleIntegrator2N)
@unpack prob = integrator.sol

integrator.finalstep = false

@trixi_timeit timer() "main loop" while !integrator.finalstep
step!(integrator)
end # "main loop" timer

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
end

function step!(integrator::SimpleIntegrator2N)
ranocha marked this conversation as resolved.
Show resolved Hide resolved
@unpack prob = integrator.sol
@unpack alg = integrator
t_end = last(prob.tspan)
callbacks = integrator.opts.callback

integrator.finalstep = false
@trixi_timeit timer() "main loop" while !integrator.finalstep
if isnan(integrator.dt)
error("time step size `dt` is NaN")
end
@assert !integrator.finalstep
if isnan(integrator.dt)
error("time step size `dt` is NaN")
end

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end

# one time step
integrator.u_tmp .= 0
for stage in eachindex(alg.c)
t_stage = integrator.t + integrator.dt * alg.c[stage]
integrator.f(integrator.du, integrator.u, prob.p, t_stage)

a_stage = alg.a[stage]
b_stage_dt = alg.b[stage] * integrator.dt
@trixi_timeit timer() "Runge-Kutta step" begin
@threaded for i in eachindex(integrator.u)
integrator.u_tmp[i] = integrator.du[i] -
integrator.u_tmp[i] * a_stage
integrator.u[i] += integrator.u_tmp[i] * b_stage_dt
end
# one time step
integrator.u_tmp .= 0
for stage in eachindex(alg.c)
t_stage = integrator.t + integrator.dt * alg.c[stage]
integrator.f(integrator.du, integrator.u, prob.p, t_stage)

a_stage = alg.a[stage]
b_stage_dt = alg.b[stage] * integrator.dt
@trixi_timeit timer() "Runge-Kutta step" begin
@threaded for i in eachindex(integrator.u)
integrator.u_tmp[i] = integrator.du[i] -
integrator.u_tmp[i] * a_stage
integrator.u[i] += integrator.u_tmp[i] * b_stage_dt
end
end
integrator.iter += 1
integrator.t += integrator.dt

# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
integrator.iter += 1
integrator.t += integrator.dt

# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
end

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
return nothing
end
end

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
end

# get a cache where the RHS can be stored
Expand Down
122 changes: 69 additions & 53 deletions src/time_integration/methods_3Sstar.jl
warisa-r marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,8 @@ function Base.getproperty(integrator::SimpleIntegrator3Sstar, field::Symbol)
return getfield(integrator, field)
end

# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1
function solve(ode::ODEProblem, alg::T;
dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm3Sstar}
function init(ode::ODEProblem, alg::T;
dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm3Sstar}
warisa-r marked this conversation as resolved.
Show resolved Hide resolved
u = copy(ode.u0)
du = similar(u)
u_tmp1 = similar(u)
Expand All @@ -199,73 +198,90 @@ function solve(ode::ODEProblem, alg::T;
error("unsupported")
end

return integrator
end

# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1
function solve(ode::ODEProblem, alg::T;
dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm3Sstar}
warisa-r marked this conversation as resolved.
Show resolved Hide resolved
integrator = init(ode, alg, dt = dt, callback = callback; kwargs...)

# Start actual solve
solve!(integrator)
end

function solve!(integrator::SimpleIntegrator3Sstar)
@unpack prob = integrator.sol

integrator.finalstep = false

@trixi_timeit timer() "main loop" while !integrator.finalstep
step!(integrator)
end # "main loop" timer

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
end

function step!(integrator::SimpleIntegrator3Sstar)
@unpack prob = integrator.sol
@unpack alg = integrator
t_end = last(prob.tspan)
callbacks = integrator.opts.callback

integrator.finalstep = false
@trixi_timeit timer() "main loop" while !integrator.finalstep
if isnan(integrator.dt)
error("time step size `dt` is NaN")
end
@assert !integrator.finalstep
if isnan(integrator.dt)
error("time step size `dt` is NaN")
end

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end

# one time step
integrator.u_tmp1 .= zero(eltype(integrator.u_tmp1))
integrator.u_tmp2 .= integrator.u
for stage in eachindex(alg.c)
t_stage = integrator.t + integrator.dt * alg.c[stage]
prob.f(integrator.du, integrator.u, prob.p, t_stage)

delta_stage = alg.delta[stage]
gamma1_stage = alg.gamma1[stage]
gamma2_stage = alg.gamma2[stage]
gamma3_stage = alg.gamma3[stage]
beta_stage_dt = alg.beta[stage] * integrator.dt
@trixi_timeit timer() "Runge-Kutta step" begin
@threaded for i in eachindex(integrator.u)
integrator.u_tmp1[i] += delta_stage * integrator.u[i]
integrator.u[i] = (gamma1_stage * integrator.u[i] +
gamma2_stage * integrator.u_tmp1[i] +
gamma3_stage * integrator.u_tmp2[i] +
beta_stage_dt * integrator.du[i])
end
# one time step
integrator.u_tmp1 .= zero(eltype(integrator.u_tmp1))
integrator.u_tmp2 .= integrator.u
for stage in eachindex(alg.c)
t_stage = integrator.t + integrator.dt * alg.c[stage]
prob.f(integrator.du, integrator.u, prob.p, t_stage)

delta_stage = alg.delta[stage]
gamma1_stage = alg.gamma1[stage]
gamma2_stage = alg.gamma2[stage]
gamma3_stage = alg.gamma3[stage]
beta_stage_dt = alg.beta[stage] * integrator.dt
@trixi_timeit timer() "Runge-Kutta step" begin
@threaded for i in eachindex(integrator.u)
integrator.u_tmp1[i] += delta_stage * integrator.u[i]
integrator.u[i] = (gamma1_stage * integrator.u[i] +
gamma2_stage * integrator.u_tmp1[i] +
gamma3_stage * integrator.u_tmp2[i] +
beta_stage_dt * integrator.du[i])
end
end
integrator.iter += 1
integrator.t += integrator.dt

# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
integrator.iter += 1
integrator.t += integrator.dt

# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
end

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
return nothing
end
end

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
end

# get a cache where the RHS can be stored
Expand Down
Loading