Skip to content

Commit

Permalink
Add JET Tests and disable Multiple Shooting tests for now
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 1, 2024
1 parent 2b0acab commit 1eeecc6
Show file tree
Hide file tree
Showing 11 changed files with 492 additions and 496 deletions.
16 changes: 9 additions & 7 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

julia_version = "1.10.0-rc2"
manifest_format = "2.0"
project_hash = "caa3a84a29ae467d2583d2084e7f50dd07a4148f"
project_hash = "0ac16cb78a3540d2a0e82de32da74fda24f340a4"

[[deps.ADTypes]]
git-tree-sha1 = "332e5d7baeff8497b923b730b994fa480601efc7"
git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
version = "0.2.5"
version = "0.2.6"

[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
Expand Down Expand Up @@ -587,10 +587,10 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.NonlinearSolve]]
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "EnumX", "FastBroadcast", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"]
git-tree-sha1 = "e61a283ef89110668b83db24cd7de8868fb8133e"
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "EnumX", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"]
git-tree-sha1 = "72b036b728461272ae1b1c3f7096cb4c319d8793"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
version = "3.3.0"
version = "3.4.0"

[deps.NonlinearSolve.extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
Expand All @@ -599,6 +599,7 @@ version = "3.3.0"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLsolveExt = "NLsolve"
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
NonlinearSolveSpeedMappingExt = "SpeedMapping"
NonlinearSolveSymbolicsExt = "Symbolics"
NonlinearSolveZygoteExt = "Zygote"
Expand All @@ -610,6 +611,7 @@ version = "3.3.0"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down Expand Up @@ -775,7 +777,7 @@ version = "0.6.42"

[[deps.SciMLBase]]
deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"]
git-tree-sha1 = "50f2e6905c201e1212871824da13a802804835d8"
git-tree-sha1 = "b8f7a0807314cce87bf846ba5fd12c1b0ef512b7"
repo-rev = "ap/nlls_bvp"
repo-url = "https://github.com/SciML/SciMLBase.jl.git"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ ArrayInterface = "7"
BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.145"
DiffEqDevTools = "2.44"
FastAlmostBandedMatrices = "0.1"
FastClosures = "0.3"
ForwardDiff = "0.10"
JET = "0.8"
LinearAlgebra = "1.9"
LinearSolve = "2.20"
NonlinearSolve = "2.6.1, 3"
Expand Down Expand Up @@ -74,6 +76,7 @@ julia = "1.9"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -83,4 +86,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "SafeTestsets", "ODEInterface", "Aqua", "LinearSolve", "RecursiveArrayTools"]
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "SafeTestsets", "ODEInterface", "Aqua", "LinearSolve", "RecursiveArrayTools", "JET"]
81 changes: 46 additions & 35 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
# Algorithms
abstract type BoundaryValueDiffEqAlgorithm <: SciMLBase.AbstractBVPAlgorithm end
abstract type AbstractShooting <: BoundaryValueDiffEqAlgorithm end
abstract type AbstractMIRK <: BoundaryValueDiffEqAlgorithm end

## Disable the ugly verbose printing by default
@inline __modifier_text!(list, fieldname, field) = push!(list, "$fieldname = $(field)")
@inline __modifier_text!(list, fieldname, ::Nothing) = list
@inline __modifier_text!(list, fieldname, ::Missing) = list
@inline function __modifier_text!(list, fieldname, field::SciMLBase.AbstractODEAlgorithm)
push!(list, "$fieldname = $(__nameof(field))()")
end

function Base.show(io::IO, alg::BoundaryValueDiffEqAlgorithm)
print(io, "$(nameof(typeof(alg)))()")
print(io, "$(__nameof(alg))(")
modifiers = String[]
for field in fieldnames(typeof(alg))
__modifier_text!(modifiers, field, getfield(alg, field))
end
print(io, join(modifiers, ", "))
print(io, ")")
end

"""
Expand All @@ -28,7 +42,7 @@ Single shooting method, reduces BVP to an initial value problem and solves the I
and problem type. If `BVPJacobianAlgorithm` is provided, only `diffmode` is used
(defaults to `AutoForwardDiff` if possible else `AutoFiniteDiff`).
"""
@concrete struct Shooting{J <: BVPJacobianAlgorithm} <: BoundaryValueDiffEqAlgorithm
@concrete struct Shooting{J <: BVPJacobianAlgorithm} <: AbstractShooting
ode_alg
nlsolve
jac_alg::J
Expand All @@ -40,51 +54,39 @@ end
@inline Shooting(ode_alg; kwargs...) = Shooting(; ode_alg, kwargs...)
@inline Shooting(ode_alg, nlsolve; kwargs...) = Shooting(; ode_alg, nlsolve, kwargs...)

function Base.show(io::IO, alg::Shooting)
print(io, "Shooting(")
modifiers = String[]
alg.nlsolve !== nothing && push!(modifiers, "nlsolve = $(alg.nlsolve)")
alg.jac_alg !== nothing && push!(modifiers, "jac_alg = $(alg.jac_alg)")
alg.ode_alg !== nothing && push!(modifiers, "ode_alg = $(__nameof(alg.ode_alg))()")
print(io, join(modifiers, ", "))
print(io, ")")
end

@inline function concretize_jacobian_algorithm(alg::Shooting, prob)
alg.jac_alg.diffmode === nothing &&
(return @set alg.jac_alg.diffmode = __default_nonsparse_ad(prob.u0))
return alg
end

"""
MultipleShooting(nshoots::Int, ode_alg = nothing; nlsolve = nothing,
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
MultipleShooting(; nshoots::Int, ode_alg = nothing, nlsolve = nothing,
grid_coarsening = true, jac_alg = nothing)
MultipleShooting(nshoots::Int; kwargs...)
MultipleShooting(nshoots::Int, ode_alg; kwargs...)
MultipleShooting(nshoots::Int, ode_alg, nlsolve; kwargs...)
Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP.
Significantly more stable than Single Shooting.
## Arguments
- `nshoots`: Number of shooting points.
- `ode_alg`: ODE algorithm to use for solving the IVP. Any solver which conforms to the
SciML `ODEProblem` interface can be used! (Defaults to `nothing` which will use
poly-algorithm if `DifferentialEquations.jl` is loaded else this must be supplied)
## Keyword Arguments
- `nlsolve`: Internal Nonlinear solver. Any solver which conforms to the SciML
`NonlinearProblem` interface can be used. Note that any autodiff argument for the solver
will be ignored and a custom jacobian algorithm will be used.
`NonlinearProblem` interface can be used.
- `jac_alg`: Jacobian Algorithm used for the nonlinear solver. Defaults to
`BVPJacobianAlgorithm()`, which automatically decides the best algorithm to use based
on the input types and problem type.
+ For `TwoPointBVProblem`, only `diffmode` is used (defaults to
`AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`).
+ For `BVProblem`, `bc_diffmode` and `nonbc_diffmode` are used. For `nonbc_diffmode`
defaults to `AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`. For
`bc_diffmode`, defaults to `AutoForwardDiff` if possible else `AutoFiniteDiff`.
we default to `AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`. For
`bc_diffmode`, we default to `AutoForwardDiff` if possible else `AutoFiniteDiff`.
- `grid_coarsening`: Coarsening the multiple-shooting grid to generate a stable IVP
solution. Possible Choices:
Expand All @@ -97,20 +99,25 @@ Significantly more stable than Single Shooting.
+ `Function`: Takes the current number of shooting points and returns the next number
of shooting points. For example, if `nshoots = 10` and
`grid_coarsening = n -> n ÷ 2`, then the grid will be coarsened to `[5, 2]`.
!!! note
For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm`
must be provided.
"""
@concrete struct MultipleShooting{J <: BVPJacobianAlgorithm}
@concrete struct MultipleShooting{J <: BVPJacobianAlgorithm} <: AbstractShooting
ode_alg
nlsolve
jac_alg::J
nshoots::Int
grid_coarsening
end

# function Base.show(io::IO, alg::MultipleShooting)
# print(io, "MultipleShooting(")
# modifiers = String[]
# alg.nlsolve !== nothing && push!(modifiers, "nlsolve = $(alg.nlsolve)")
# alg.jac_alg !== nothing && push!(modifiers, "jac_alg = $(alg.jac_alg)")
# alg.ode_alg !== nothing && push!(modifiers, "ode_alg = $(__nameof(alg.ode_alg))()")
# print(io, join(modifiers, ", "))
# print(io, ")")
# end

function concretize_jacobian_algorithm(alg::MultipleShooting, prob)
jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
return MultipleShooting(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots,
Expand All @@ -122,18 +129,22 @@ function update_nshoots(alg::MultipleShooting, nshoots::Int)
alg.grid_coarsening)
end

function MultipleShooting(nshoots::Int, ode_alg = nothing; nlsolve = nothing,
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
@assert grid_coarsening isa Bool || grid_coarsening isa Function ||
grid_coarsening isa AbstractVector{<:Integer} ||
grid_coarsening isa NTuple{N, <:Integer} where {N}
function MultipleShooting(; nshoots::Int, ode_alg = nothing, nlsolve = nothing,
grid_coarsening::Union{Bool, Function, <:AbstractVector{<:Integer},
Tuple{Vararg{Integer}}} = true, jac_alg = nothing)
grid_coarsening isa Tuple && (grid_coarsening = Vector(grid_coarsening...))
if grid_coarsening isa AbstractVector
sort!(grid_coarsening; rev = true)
@assert all(grid_coarsening .> 0) && 1 grid_coarsening
end
return MultipleShooting(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening)
return MultipleShooting(ode_alg, nlsolve,
__materialize_jacobian_algorithm(nlsolve, jac_alg), nshoots, grid_coarsening)
end
@inline MultipleShooting(nshoots::Int; kwargs...) = MultipleShooting(; nshoots, kwargs...)
@inline MultipleShooting(nshoots::Int, ode_alg; kwargs...) = MultipleShooting(;
nshoots, ode_alg, kwargs...)
@inline MultipleShooting(nshoots::Int, ode_alg, nlsolve; kwargs...) = MultipleShooting(;
nshoots, ode_alg, nlsolve, kwargs...)

for order in (2, 3, 4, 5, 6)
alg = Symbol("MIRK$(order)")
Expand Down
9 changes: 4 additions & 5 deletions src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,15 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
nlf = __unsafe_nonlinearfunction{iip}(loss_fn; jac_prototype, resid_prototype,
jac = jac_fn)
nlprob = __internal_nlsolve_problem(prob, resid_prototype, u0, nlf, vec(u0), prob.p)
opt = __solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)
nlsol = __solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)

# There is no way to reinit with the same cache with different cache. But not saving
# the internal values gives a significant speedup. So we just create a new cache
internal_prob_final = ODEProblem{iip}(prob.f, reshape(opt.u, u0_size), prob.tspan,
internal_prob_final = ODEProblem{iip}(prob.f, reshape(nlsol.u, u0_size), prob.tspan,
prob.p)
sol = __solve(internal_prob_final, alg.ode_alg; actual_ode_kwargs...)
odesol = __solve(internal_prob_final, alg.ode_alg; actual_ode_kwargs...)

retcode = SciMLBase.successful_retcode(opt) ? sol.retcode : opt.retcode
return BVPSolution(sol; original = opt, retcode)
return SciMLBase.build_solution(prob, odesol, nlsol)
end

function __single_shooting_loss!(resid_, u0_, p, cache, bc::BC, u0_size,
Expand Down
100 changes: 0 additions & 100 deletions test/misc/odeinterface.jl

This file was deleted.

Loading

0 comments on commit 1eeecc6

Please sign in to comment.