Skip to content

Commit

Permalink
Tame the beast of MS runtime dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 22, 2023
1 parent 2c81ce1 commit 3edc427
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 112 deletions.
10 changes: 6 additions & 4 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.0-rc3"
julia_version = "1.10.0-rc2"
manifest_format = "2.0"
project_hash = "28b32429fc2420e2acf9060b5ab6e19ef75ab1aa"

Expand Down Expand Up @@ -168,9 +168,11 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.DiffEqBase]]
deps = ["ArrayInterface", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces"]
git-tree-sha1 = "83114bb5158ca74ac9dee3d01edf87cc0363b5db"
git-tree-sha1 = "872a3bac710f3f393758c0e4b162f9f1cefabc24"
repo-rev = "ap/nlls_bvp"
repo-url = "https://github.com/SciML/DiffEqBase.jl.git"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
version = "6.144.1"
version = "6.144.2"

[deps.DiffEqBase.extensions]
DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -917,7 +919,7 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "7.2.1+1"

Expand Down
13 changes: 5 additions & 8 deletions ext/BoundaryValueDiffEqOrdinaryDiffEqExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ end
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))))
end

if load_preference(BoundaryValueDiffEq, "PrecompileMultipleShooting", false)
if load_preference(BoundaryValueDiffEq, "PrecompileMultipleShooting", true)
push!(algs,
MultipleShooting(10, Tsit5();
nlsolve = NewtonRaphson(; autodiff = AutoForwardDiff(chunksize = 2)),
MultipleShooting(10, Tsit5(); nlsolve = NewtonRaphson(),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))))
Expand Down Expand Up @@ -119,16 +118,14 @@ end
])
end

if load_preference(BoundaryValueDiffEq, "PrecompileMultipleShootingNLLS", false)
if load_preference(BoundaryValueDiffEq, "PrecompileMultipleShootingNLLS", true)
append!(algs,
[
MultipleShooting(10, Tsit5();
nlsolve = TrustRegion(; autodiff = AutoForwardDiff(chunksize = 2)),
MultipleShooting(10, Tsit5(); nlsolve = TrustRegion(),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))),
MultipleShooting(10, Tsit5();
nlsolve = GaussNewton(; autodiff = AutoForwardDiff(chunksize = 2)),
MultipleShooting(10, Tsit5(); nlsolve = GaussNewton(),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))),
Expand Down
15 changes: 13 additions & 2 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ end
include("types.jl")
include("utils.jl")
include("algorithms.jl")
include("alg_utils.jl")

include("tableaus/mirk.jl")

Expand All @@ -43,6 +42,18 @@ include("sparse_jacobians.jl")
include("adaptivity.jl")
include("interpolation.jl")

# Define `alg_order` and `alg_stage` for each MIRK method: for `MIRK<k>`, `alg_order`
# returns `k` and `alg_stage` returns `k - 1`.
for k in 2:6
    mirk = Symbol("MIRK", k)
    @eval begin
        alg_order(::$(mirk)) = $k
        alg_stage(::$(mirk)) = $(k - 1)
    end
end

# Trait: every `BoundaryValueDiffEqAlgorithm` reports itself as auto-differentiable.
function SciMLBase.isautodifferentiable(::BoundaryValueDiffEqAlgorithm)
    return true
end

Check warning on line 51 in src/BoundaryValueDiffEq.jl

View check run for this annotation

Codecov / codecov/patch

src/BoundaryValueDiffEq.jl#L51

Added line #L51 was not covered by tests
# Traits: every `BoundaryValueDiffEqAlgorithm` reports support for arbitrary number types
# and for complex numbers.
function SciMLBase.allows_arbitrary_number_types(::BoundaryValueDiffEqAlgorithm)
    return true
end
function SciMLBase.allowscomplex(::BoundaryValueDiffEqAlgorithm)
    return true
end

# Trait: every `AbstractMIRK` method reports itself as adaptive.
function SciMLBase.isadaptive(::AbstractMIRK)
    return true
end

function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
Expand Down Expand Up @@ -159,7 +170,7 @@ end

export Shooting, MultipleShooting
export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
export MIRKJacobianComputationAlgorithm, BVPJacobianAlgorithm
export BVPJacobianAlgorithm
# From ODEInterface.jl
export BVPM2, BVPSOL

Expand Down
11 changes: 0 additions & 11 deletions src/alg_utils.jl

This file was deleted.

68 changes: 28 additions & 40 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,7 @@ Single shooting method, reduces BVP to an initial value problem and solves the I
end

# Keyword constructor for `Shooting`.
#
# Keyword arguments (all default to `nothing`):
#   - `ode_alg`: stored on the algorithm as given.
#   - `nlsolve`: stored on the algorithm as given.
#   - `jac_alg`: not stored directly. It is passed, together with `nlsolve`, to
#     `__materialize_jacobian_algorithm`, and the value that helper returns is stored.
#
# NOTE(review): `__materialize_jacobian_algorithm` is defined outside this view. It
# presumably covers the cases the former inline `if`/`elseif` chain handled (`nothing`,
# an `ADTypes.AbstractADType`, or a ready `BVPJacobianAlgorithm`) -- confirm against the
# helper. The former inline chain was removed because its error branch interpolated the
# still-unassigned local `_jac_alg`, and it made the delegating return unreachable.
function Shooting(; ode_alg = nothing, nlsolve = nothing, jac_alg = nothing)
    resolved_jac_alg = __materialize_jacobian_algorithm(nlsolve, jac_alg)
    return Shooting(ode_alg, nlsolve, resolved_jac_alg)
end
# Positional conveniences: forward the leading positional arguments to the keyword
# constructor `Shooting(; ...)` under the same names.
@inline function Shooting(ode_alg; kwargs...)
    return Shooting(; ode_alg = ode_alg, kwargs...)
end
@inline function Shooting(ode_alg, nlsolve; kwargs...)
    return Shooting(; ode_alg = ode_alg, nlsolve = nlsolve, kwargs...)
end

Check warning on line 36 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L36

Added line #L36 was not covered by tests
Expand All @@ -66,8 +52,11 @@ end
end

"""
MultipleShooting(nshoots::Int, ode_alg = nothing; nlsolve = nothing,
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
MultipleShooting(; nshoots::Int, ode_alg = nothing, nlsolve = nothing,
grid_coarsening = true, jac_alg = nothing)
MultipleShooting(nshoots::Int; kwargs...)
MultipleShooting(nshoots::Int, ode_alg; kwargs...)
MultipleShooting(nshoots::Int, ode_alg, nlsolve; kwargs...)
Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP.
Significantly more stable than Single Shooting.
Expand All @@ -78,20 +67,16 @@ Significantly more stable than Single Shooting.
- `ode_alg`: ODE algorithm to use for solving the IVP. Any solver which conforms to the
SciML `ODEProblem` interface can be used! (Defaults to `nothing`, which uses the default
poly-algorithm if `DifferentialEquations.jl` is loaded; otherwise it must be supplied.)
## Keyword Arguments
- `nlsolve`: Internal Nonlinear solver. Any solver which conforms to the SciML
`NonlinearProblem` interface can be used. Note that any autodiff argument for the solver
will be ignored and a custom jacobian algorithm will be used.
`NonlinearProblem` interface can be used.
- `jac_alg`: Jacobian Algorithm used for the nonlinear solver. Defaults to
`BVPJacobianAlgorithm()`, which automatically decides the best algorithm to use based
on the input types and problem type.
- For `TwoPointBVProblem`, only `diffmode` is used (defaults to
`AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`).
- For `BVProblem`, `bc_diffmode` and `nonbc_diffmode` are used. For `nonbc_diffmode`
defaults to `AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`. For
`bc_diffmode`, defaults to `AutoForwardDiff` if possible else `AutoFiniteDiff`.
we default to `AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`. For
`bc_diffmode`, we default to `AutoForwardDiff` if possible else `AutoFiniteDiff`.
- `grid_coarsening`: Coarsening the multiple-shooting grid to generate a stable IVP
solution. Possible Choices:
- `true`: Halve the grid size, till we reach a grid size of 1.
Expand All @@ -103,10 +88,6 @@ Significantly more stable than Single Shooting.
- `Function`: Takes the current number of shooting points and returns the next number
of shooting points. For example, if `nshoots = 10` and
`grid_coarsening = n -> n ÷ 2`, then the grid will be coarsened to `[5, 2]`.
!!! note
For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm`
must be provided.
"""
@concrete struct MultipleShooting{J <: BVPJacobianAlgorithm}
ode_alg
Expand All @@ -127,25 +108,30 @@ function update_nshoots(alg::MultipleShooting, nshoots::Int)
alg.grid_coarsening)
end

# Keyword constructor for `MultipleShooting`.
#
# Keyword arguments:
#   - `nshoots` (required, `Int`): stored on the algorithm as given.
#   - `ode_alg`, `nlsolve` (default `nothing`): stored as given.
#   - `grid_coarsening` (default `true`): a `Bool`, a `Function`, a vector of integers or
#     a tuple of integers. A tuple is first converted to a vector. A vector is sorted in
#     descending order and must contain only positive entries, none of them equal to 1.
#   - `jac_alg` (default `nothing`): passed, together with `nlsolve`, to
#     `__materialize_jacobian_algorithm`; the value that helper returns is stored.
#
# Throws an `AssertionError` if a vector/tuple `grid_coarsening` has a non-positive entry
# or contains 1.
function MultipleShooting(; nshoots::Int, ode_alg = nothing, nlsolve = nothing,
        grid_coarsening::Union{Bool, Function, <:AbstractVector{<:Integer},
            Tuple{Vararg{Integer}}} = true, jac_alg = nothing)
    # `collect` turns the tuple into a vector; splatting it into `Vector(...)` would
    # raise a `MethodError`.
    coarsening = grid_coarsening isa Tuple ? collect(grid_coarsening) : grid_coarsening
    if coarsening isa AbstractVector
        # Sort a copy so the caller's vector is not reordered and immutable ranges work.
        coarsening = sort(coarsening; rev = true)
        @assert all(>(0), coarsening) && 1 ∉ coarsening
    end
    return MultipleShooting(ode_alg, nlsolve,
        __materialize_jacobian_algorithm(nlsolve, jac_alg), nshoots, coarsening)
end
# Positional convenience: forward `nshoots` to the keyword constructor.
@inline function MultipleShooting(nshoots::Int; kwargs...)
    return MultipleShooting(; nshoots = nshoots, kwargs...)
end

Check warning on line 122 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L122

Added line #L122 was not covered by tests
# Positional convenience: forward `nshoots` and `ode_alg` to the keyword constructor.
@inline function MultipleShooting(nshoots::Int, ode_alg; kwargs...)
    return MultipleShooting(; nshoots = nshoots, ode_alg = ode_alg, kwargs...)
end
@inline MultipleShooting(nshoots::Int, ode_alg, nlsolve; kwargs...) = MultipleShooting(;

Check warning on line 125 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L125

Added line #L125 was not covered by tests
nshoots, ode_alg, nlsolve, kwargs...)

for order in (2, 3, 4, 5, 6)
alg = Symbol("MIRK$(order)")

@eval begin
"""
$($alg)(; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm())
$($alg)(; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm(),
max_num_subintervals = 3000, defect_threshold = 0.1)
$($order)th order Mono-Implicit Runge-Kutta (MIRK) method.
Expand All @@ -163,10 +149,10 @@ for order in (2, 3, 4, 5, 6)
`nonbc_diffmode` defaults to `AutoSparseForwardDiff` if possible else
`AutoSparseFiniteDiff`. For `bc_diffmode`, defaults to `AutoForwardDiff` if
possible else `AutoFiniteDiff`.
!!! note
For type-stability, the chunksizes for ForwardDiff ADTypes in
`BVPJacobianAlgorithm` must be provided.
- `max_num_subintervals`: Maximum number of subintervals to use for the
adaptive stepsize control.
- `defect_threshold`: Threshold for the defect control. If the defect is larger
than this threshold, the step is rejected and the stepsize is reduced.
## References
Expand All @@ -181,9 +167,11 @@ for order in (2, 3, 4, 5, 6)
}
```
"""
Base.@kwdef struct $(alg){N, J <: BVPJacobianAlgorithm} <: AbstractMIRK
Base.@kwdef struct $(alg){N, J <: BVPJacobianAlgorithm, T} <: AbstractMIRK
nlsolve::N = nothing
jac_alg::J = BVPJacobianAlgorithm()
max_num_subintervals::Int = 3000
defect_threshold::T = 0.1
end
end
end
Expand Down
57 changes: 27 additions & 30 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
@unpack f, tspan = prob
(; f, tspan) = prob

@assert (ensemblealg isa EnsembleSerial)||(ensemblealg isa EnsembleThreads) "Currently MultipleShooting only supports `EnsembleSerial` and `EnsembleThreads`!"
if !(ensemblealg isa EnsembleSerial) && !(ensemblealg isa EnsembleThreads)
throw(ArgumentError("Currently MultipleShooting only supports `EnsembleSerial` and \

Check warning on line 6 in src/solve/multiple_shooting.jl

View check run for this annotation

Codecov / codecov/patch

src/solve/multiple_shooting.jl#L6

Added line #L6 was not covered by tests
`EnsembleThreads`!"))
end

ig, T, N, Nig, u0 = __extract_problem_details(prob; dt = 0.1)
has_initial_guess = _unwrap_val(ig)
Expand Down Expand Up @@ -30,11 +33,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),

internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true)

function solve_internal_odes!(resid_nodes::T1, us::T2, p::T3, cur_nshoot::Int,
nodes::T4, odecache::C) where {T1, T2, T3, T4, C}
return __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot,
odecache, nodes, u0_size, N, ensemblealg)
end
solve_internal_odes! = @closure (resid_nodes, us, p, cur_nshoot, nodes, odecache) -> __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot,
odecache, nodes, u0_size, N, ensemblealg, tspan)

# This gets all the nshoots except the final SingleShooting case
all_nshoots = __get_all_nshoots(alg.grid_coarsening, nshoots)
Expand Down Expand Up @@ -96,7 +96,7 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
resid_prototype = vcat(bcresid_prototype[1],
similar(u_at_nodes, cur_nshoot * N), bcresid_prototype[2])

loss_fn = (du, u, p) -> __multiple_shooting_2point_loss!(du, u, p, cur_nshoot,
loss_fn = @closure (du, u, p) -> __multiple_shooting_2point_loss!(du, u, p, cur_nshoot,
nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb,
ode_cache_loss_fn)

Expand All @@ -112,19 +112,18 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
jac_cache, alg.jac_alg.diffmode, alg.ode_alg, cur_nshoot, u0;
internal_ode_kwargs...)

loss_fnₚ = (du, u) -> __multiple_shooting_2point_loss!(du, u, prob.p, cur_nshoot,
nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb,
loss_fnₚ = @closure (du, u) -> __multiple_shooting_2point_loss!(du, u, prob.p,
cur_nshoot, nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb,
ode_cache_jac_fn)

jac_fn = (J, u, p) -> __multiple_shooting_2point_jacobian!(J, u, p, jac_cache,
jac_fn = @closure (J, u, p) -> __multiple_shooting_2point_jacobian!(J, u, p, jac_cache,
loss_fnₚ, resid_prototype_cached, alg)

loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
jac_prototype)
loss_function! = __unsafe_nonlinearfunction{true}(loss_fn; resid_prototype,
jac = jac_fn, jac_prototype)

# NOTE: u_at_nodes is updated inplace
nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
u_at_nodes, prob.p)
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)

return nothing
Expand All @@ -144,7 +143,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
resid_nodes = __maybe_allocate_diffcache(__resid_nodes,
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode)

loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
loss_fn = @closure (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob.tspan,
alg.ode_alg, u0, ode_cache_loss_fn)

Expand All @@ -169,22 +168,21 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache))

# Define the functions now
ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes,
ode_fn = @closure (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes,
ode_cache_ode_jac_fn)
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p, cur_nshoot, nodes,
prob, solve_internal_odes!, N, f, bc, u0_size, prob.tspan, alg.ode_alg, u0,
ode_cache_bc_jac_fn)
bc_fn = @closure (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p,
cur_nshoot, nodes, prob, solve_internal_odes!, N, f, bc, u0_size, prob.tspan,
alg.ode_alg, u0, ode_cache_bc_jac_fn)

jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
jac_fn = @closure (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
ode_fn, bc_fn, alg, N, M)

loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
jac_prototype)
loss_function! = __unsafe_nonlinearfunction{true}(loss_fn; resid_prototype,
jac_prototype, jac = jac_fn)

# NOTE: u_at_nodes is updated inplace
nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
u_at_nodes, prob.p)
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)

return nothing
Expand Down Expand Up @@ -224,7 +222,7 @@ end

# Not using `EnsembleProblem` since it is hard to initialize the cache and stuff
function __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoots::Int,
odecache, nodes, u0_size, N::Int, ::EnsembleSerial)
odecache, nodes, u0_size, N::Int, ::EnsembleSerial, tspan)
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots)

Expand All @@ -242,7 +240,7 @@ function __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoots::
end

function __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoots::Int,
odecache::Vector, nodes, u0_size, N::Int, ::EnsembleThreads)
odecache::Vector, nodes, u0_size, N::Int, ::EnsembleThreads, tspan)
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots)

Expand Down Expand Up @@ -399,14 +397,13 @@ end
sol = solve!(odecache)

if SciMLBase.successful_retcode(sol)
res = sol(nodes).u
for i in 1:length(nodes)
u_at_nodes[(i - 1) * N .+ (1:N)] .= vec(res[i])
u_at_nodes[(i - 1) * N .+ (1:N)] .= vec(sol(nodes[i]))
end
else
@warn "Initialization using odesolve failed. Initializing using 0s. It is \
recommended to provide an initial guess function via \
`u0 = <function>(p, t)` or `u0 = <function>(t)` in this case."
`u0 = <function>(p, t)` in this case."
fill!(u_at_nodes, 0)
end

Expand Down
Loading

0 comments on commit 3edc427

Please sign in to comment.