Skip to content

Commit

Permalink
Run JuliaFormatter.format()
Browse files Browse the repository at this point in the history
                Using JuliaFormatter v1.0.47
  • Loading branch information
LilithHafner committed Feb 14, 2024
1 parent 4be662e commit e2ea47b
Show file tree
Hide file tree
Showing 15 changed files with 128 additions and 83 deletions.
30 changes: 20 additions & 10 deletions benchmark/simple_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,42 @@ function create_simple_pendulum_benchmark()
suite["OOP"] = oop_suite

if @isdefined(MultipleShooting)
iip_suite["MultipleShooting(100, Tsit5; grid_coarsening = true)"] = @benchmarkable solve($SimplePendulumBenchmark.prob_iip,
iip_suite["MultipleShooting(100, Tsit5; grid_coarsening = true)"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip,
$MultipleShooting(100, Tsit5()))
iip_suite["MultipleShooting(100, Tsit5; grid_coarsening = false)"] = @benchmarkable solve($SimplePendulumBenchmark.prob_iip,
iip_suite["MultipleShooting(100, Tsit5; grid_coarsening = false)"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip,
$MultipleShooting(100, Tsit5(); grid_coarsening = false))
iip_suite["MultipleShooting(10, Tsit5; grid_coarsening = true)"] = @benchmarkable solve($SimplePendulumBenchmark.prob_iip,
iip_suite["MultipleShooting(10, Tsit5; grid_coarsening = true)"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip,
$MultipleShooting(10, Tsit5()))
iip_suite["MultipleShooting(10, Tsit5; grid_coarsening = false)"] = @benchmarkable solve($SimplePendulumBenchmark.prob_iip,
iip_suite["MultipleShooting(10, Tsit5; grid_coarsening = false)"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip,
$MultipleShooting(10, Tsit5(); grid_coarsening = false))
end
if @isdefined(Shooting)
iip_suite["Shooting(Tsit5())"] = @benchmarkable solve($SimplePendulumBenchmark.prob_iip,
iip_suite["Shooting(Tsit5())"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip,
$Shooting(Tsit5()))
end

if @isdefined(MultipleShooting)
oop_suite["MultipleShooting(100, Tsit5; grid_coarsening = true)"] = @benchmarkable solve($SimplePendulumBenchmark.prob_oop,
oop_suite["MultipleShooting(100, Tsit5; grid_coarsening = true)"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop,
$MultipleShooting(100, Tsit5()))
oop_suite["MultipleShooting(100, Tsit5; grid_coarsening = false)"] = @benchmarkable solve($SimplePendulumBenchmark.prob_oop,
oop_suite["MultipleShooting(100, Tsit5; grid_coarsening = false)"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop,
$MultipleShooting(100, Tsit5(); grid_coarsening = false))
oop_suite["MultipleShooting(10, Tsit5; grid_coarsening = true)"] = @benchmarkable solve($SimplePendulumBenchmark.prob_oop,
oop_suite["MultipleShooting(10, Tsit5; grid_coarsening = true)"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop,
$MultipleShooting(10, Tsit5()))
oop_suite["MultipleShooting(10, Tsit5; grid_coarsening = false)"] = @benchmarkable solve($SimplePendulumBenchmark.prob_oop,
oop_suite["MultipleShooting(10, Tsit5; grid_coarsening = false)"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop,
$MultipleShooting(10, Tsit5(); grid_coarsening = false))
end
if @isdefined(Shooting)
oop_suite["Shooting(Tsit5())"] = @benchmarkable solve($SimplePendulumBenchmark.prob_oop,
oop_suite["Shooting(Tsit5())"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop,
$Shooting(Tsit5()))
end

Expand Down
48 changes: 28 additions & 20 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ module BoundaryValueDiffEqODEInterfaceExt
using SciMLBase, BoundaryValueDiffEq, ODEInterface
import SciMLBase: __solve
import ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS, OPT_SOLMETHOD,
OPT_RHS_CALLMODE, OPT_COLLOCATIONPTS, OPT_MAXSUBINTERVALS, RHS_CALL_INSITU, evalSolution
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS,
OPT_SOLMETHOD,
OPT_RHS_CALLMODE, OPT_COLLOCATIONPTS, OPT_MAXSUBINTERVALS,
RHS_CALL_INSITU, evalSolution
import ODEInterface: Bvpm2, bvpm2_init, bvpm2_solve, bvpm2_destroy, bvpm2_get_x
import ODEInterface: bvpsol
import ODEInterface: colnew

import ForwardDiff

function _test_bvpm2_bvpsol_colnew_problem_criteria(_, ::SciMLBase.StandardBVProblem, alg::Symbol)
function _test_bvpm2_bvpsol_colnew_problem_criteria(

Check warning on line 16 in ext/BoundaryValueDiffEqODEInterfaceExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BoundaryValueDiffEqODEInterfaceExt.jl#L16

Added line #L16 was not covered by tests
_, ::SciMLBase.StandardBVProblem, alg::Symbol)
throw(ArgumentError("$(alg) does not support standard BVProblem. Only TwoPointBVProblem is supported."))
end
function _test_bvpm2_bvpsol_colnew_problem_criteria(prob, ::TwoPointBVProblem, alg::Symbol)
Expand Down Expand Up @@ -54,7 +57,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3,

sol, retcode, stats = bvpm2_solve(initial_guess, bvp2m_f, bvp2m_bc, opt)
retcode = retcode 0 ? ReturnCode.Success : ReturnCode.Failure
destats = SciMLBase.DEStats(stats["no_rhs_calls"], 0, 0, 0, stats["no_jac_calls"], 0, 0, 0, 0, 0, 0)
destats = SciMLBase.DEStats(
stats["no_rhs_calls"], 0, 0, 0, stats["no_jac_calls"], 0, 0, 0, 0, 0, 0)

x_mesh = bvpm2_get_x(sol)
evalsol = evalSolution(sol, x_mesh)
Expand All @@ -72,7 +76,7 @@ end
#-------
function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
dt = 0.0, verbose = true, kwargs...)
_test_bvpm2_bvpsol_colnew_problem_criteria(prob, prob.problem_type, :BVPSOL)
_test_bvpm2_bvpsol_colnew_problem_criteria(prob, prob.problem_type, :BVPSOL)
@assert isa(prob.p, SciMLBase.NullParameters) "BVPSOL only supports NullParameters!"
@assert isa(prob.u0, AbstractVector{<:AbstractArray}) "BVPSOL requires a vector of initial guesses!"
n, u0 = (length(prob.u0) - 1), reduce(hcat, prob.u0)
Expand Down Expand Up @@ -125,7 +129,8 @@ end
#-------
# COLNEW
#-------
function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000, reltol=1e-4, dt = 0.0, verbose = true, kwargs...)
function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000,
reltol = 1e-4, dt = 0.0, verbose = true, kwargs...)
_test_bvpm2_bvpsol_colnew_problem_criteria(prob, prob.problem_type, :COLNEW)
has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
dt 0 && throw(ArgumentError("dt must be positive"))
Expand All @@ -136,18 +141,20 @@ function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000, reltol
end
T = eltype(u0)
mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
opt = OptionsODE(OPT_BVPCLASS => alg.bvpclass, OPT_COLLOCATIONPTS => alg.collocationpts,
opt = OptionsODE(
OPT_BVPCLASS => alg.bvpclass, OPT_COLLOCATIONPTS => alg.collocationpts,
OPT_MAXSTEPS => maxiters, OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
OPT_MAXSUBINTERVALS => alg.max_num_subintervals, OPT_RTOL => reltol)
OPT_MAXSUBINTERVALS => alg.max_num_subintervals, OPT_RTOL => reltol)
orders = ones(Int, no_odes)
_tspan = [prob.tspan[1], prob.tspan[2]]
iip = SciMLBase.isinplace(prob)

rhs(t, u, du) = if iip
prob.f(du, u, prob.p, t)
else
(du .= prob.f(u, prob.p, t))
end
rhs(t, u, du) =
if iip
prob.f(du, u, prob.p, t)
else
(du .= prob.f(u, prob.p, t))

Check warning on line 156 in ext/BoundaryValueDiffEqODEInterfaceExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BoundaryValueDiffEqODEInterfaceExt.jl#L156

Added line #L156 was not covered by tests
end

if prob.f.jac === nothing
if iip
Expand All @@ -169,38 +176,38 @@ function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000, reltol
end
Drhs(t, u, df) = jac(df, u, prob.p, t)


#TODO: Fix bc and bcjac for multi-points BVP

n_bc_a = length(first(prob.f.bcresid_prototype.x))
n_bc_b = length(last(prob.f.bcresid_prototype.x))
zeta = vcat(fill(first(prob.tspan), n_bc_a), fill(last(prob.tspan), n_bc_b))
bc = function (i, z, resid)
tmpa = copy(z); tmpb = copy(z)
tmpa = copy(z)
tmpb = copy(z)
tmp_resid_a = zeros(T, n_bc_a)
tmp_resid_b = zeros(T, n_bc_b)
prob.f.bc[1](tmp_resid_a, tmpa, prob.p)
prob.f.bc[2](tmp_resid_b, tmpb, prob.p)

for j=1:n_bc_a
for j in 1:n_bc_a
if i == j
resid[1] = tmp_resid_a[j]
end
end
for j=1:n_bc_b
for j in 1:n_bc_b
if i == (j + n_bc_a)
resid[1] = tmp_resid_b[j]
end
end
end

Dbc = function (i, z, dbc)
for j=1:n_bc_a
for j in 1:n_bc_a
if i == j
dbc[i] = 1.0
end
end
for j=1:n_bc_b
for j in 1:n_bc_b
if i == (j + n_bc_a)
dbc[i] = 1.0
end
Expand All @@ -222,7 +229,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000, reltol
end

evalsol = evalSolution(sol, mesh)
destats = SciMLBase.DEStats(stats["no_rhs_calls"], 0, 0, 0, stats["no_jac_calls"], 0, 0, 0, 0, 0, 0)
destats = SciMLBase.DEStats(
stats["no_rhs_calls"], 0, 0, 0, stats["no_jac_calls"], 0, 0, 0, 0, 0, 0)

return DiffEqBase.build_solution(prob, alg, mesh,
collect(Vector{eltype(evalsol)}, eachrow(evalsol));
Expand Down
8 changes: 4 additions & 4 deletions ext/BoundaryValueDiffEqOrdinaryDiffEqExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end
BVProblem(f1!, bc1!, u0, tspan),
BVProblem(f1, bc1, u0, tspan),
TwoPointBVProblem(f1!, (bc1_a!, bc1_b!), u0, tspan; bcresid_prototype),
TwoPointBVProblem(f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype),
TwoPointBVProblem(f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype)
]

algs = []
Expand Down Expand Up @@ -101,7 +101,7 @@ end
TwoPointBVProblem(f1_nlls!, (bc1_nlls_a!, bc1_nlls_b!), u0, tspan;
bcresid_prototype = bcresid_prototype2),
TwoPointBVProblem(f1_nlls, (bc1_nlls_a, bc1_nlls_b), u0, tspan;
bcresid_prototype = bcresid_prototype2),
bcresid_prototype = bcresid_prototype2)
]

algs = []
Expand All @@ -112,7 +112,7 @@ end
Shooting(Tsit5(); nlsolve = LevenbergMarquardt(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))),
Shooting(Tsit5(); nlsolve = GaussNewton(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)))
])
end

Expand All @@ -129,7 +129,7 @@ end
nlsolve = GaussNewton(; autodiff = AutoForwardDiff(chunksize = 2)),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2)))
])
end

Expand Down
13 changes: 7 additions & 6 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat

@recompile_invalidations begin
using ADTypes, Adapt, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
PreallocationTools, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield,
SparseDiffTools, Tricks
PreallocationTools, Preferences, RecursiveArrayTools, Reexport, SciMLBase,
Setfield,
SparseDiffTools, Tricks

# Special Matrix Types
using BandedMatrices, FastAlmostBandedMatrices, SparseArrays

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors,
parameterless_type, undefmatrix, fast_scalar_indexing
parameterless_type, undefmatrix, fast_scalar_indexing
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: pickchunksize
Expand Down Expand Up @@ -74,7 +75,7 @@ end
BVProblem(f1!, bc1!, u0, tspan),
BVProblem(f1, bc1, u0, tspan),
TwoPointBVProblem(f1!, (bc1_a!, bc1_b!), u0, tspan; bcresid_prototype),
TwoPointBVProblem(f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype),
TwoPointBVProblem(f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype)
]

algs = []
Expand Down Expand Up @@ -129,7 +130,7 @@ end
TwoPointBVProblem(f1_nlls!, (bc1_nlls_a!, bc1_nlls_b!), u0, tspan;
bcresid_prototype = bcresid_prototype2),
TwoPointBVProblem(f1_nlls, (bc1_nlls_a, bc1_nlls_b), u0, tspan;
bcresid_prototype = bcresid_prototype2),
bcresid_prototype = bcresid_prototype2)
]

jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))
Expand All @@ -144,7 +145,7 @@ end
[
MIRK2(; jac_alg, nlsolve), MIRK3(; jac_alg, nlsolve),
MIRK4(; jac_alg, nlsolve), MIRK5(; jac_alg, nlsolve),
MIRK6(; jac_alg, nlsolve),
MIRK6(; jac_alg, nlsolve)
])
end
end
Expand Down
8 changes: 5 additions & 3 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ function __propagate_nlsolve_ad_to_jac_alg(nlsolve::N) where {N}
ad = hasfield(N, :ad) ? nlsolve.ad : nothing
ad === nothing && return BVPJacobianAlgorithm()

Base.depwarn("Setting autodiff to the nonlinear solver in Shooting has been deprecated \
and will have no effect from the next major release. Update to use \
`BVPJacobianAlgorithm` directly", :Shooting)
Base.depwarn(

Check warning on line 58 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L58

Added line #L58 was not covered by tests
"Setting autodiff to the nonlinear solver in Shooting has been deprecated \
and will have no effect from the next major release. Update to use \
`BVPJacobianAlgorithm` directly",
:Shooting)
return BVPJacobianAlgorithm(ad)
end

Expand Down
28 changes: 14 additions & 14 deletions src/mirk_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function constructMIRK3(::Type{T}) where {T}
v = [0, 4 // 9]
b = [1 // 4, 3 // 4]
x = [0 0
2//9 0]
2//9 0]

# Interpolant tableau
s_star = 3
Expand All @@ -52,8 +52,8 @@ function constructMIRK4(::Type{T}) where {T}
v = [0, 1, 1 // 2, 27 // 32]
b = [1 // 6, 1 // 6, 2 // 3, 0]
x = [0 0 0 0
0 0 0 0
1//8 -1//8 0 0]
0 0 0 0
1//8 -1//8 0 0]

# Interpolant tableau
s_star = 4
Expand All @@ -74,16 +74,16 @@ function constructMIRK5(::Type{T}) where {T}
v = [0, 1, 27 // 32, 837 // 1250]
b = [5 // 54, 1 // 14, 32 // 81, 250 // 567]
x = [0 0 0 0
0 0 0 0
3//64 -9//64 0 0
21//1000 63//5000 -252//625 0]
0 0 0 0
3//64 -9//64 0 0
21//1000 63//5000 -252//625 0]

# Interpolant tableau
s_star = 6
c_star = [4 // 5, 13 // 23]
v_star = [4 // 5, 13 // 23]
x_star = [14//1125 -74//875 -128//3375 104//945 0 0
1//2 4508233//1958887 48720832//2518569 -27646420//17629983 -11517095//559682 0]
1//2 4508233//1958887 48720832//2518569 -27646420//17629983 -11517095//559682 0]
τ_star = 0.3

TU = MIRKTableau(s, T.(c), T.(v), T.(b), T.(x))
Expand All @@ -98,19 +98,19 @@ function constructMIRK6(::Type{T}) where {T}
v = [0, 1, 5 // 32, 27 // 32, 1 // 2]
b = [7 // 90, 7 // 90, 16 // 45, 16 // 45, 2 // 15, 0, 0, 0, 0]
x = [0 0 0 0 0
0 0 0 0 0
9//64 -3//64 0 0 0
3//64 -9//64 0 0 0
-5//24 5//24 2//3 -2//3 0]
0 0 0 0 0
9//64 -3//64 0 0 0
3//64 -9//64 0 0 0
-5//24 5//24 2//3 -2//3 0]

# Interpolant tableau
s_star = 9
c_star = [7 // 16, 3 // 8, 9 // 16, 1 // 8]
v_star = [7 // 16, 3 // 8, 9 // 16, 1 // 8]
x_star = [1547//32768 -1225//32768 749//4096 -287//2048 -861//16384 0 0 0 0
83//1536 -13//384 283//1536 -167//1536 -49//512 0 0 0 0
1225//32768 -1547//32768 287//2048 -749//4096 861//16384 0 0 0 0
233//3456 -19//1152 0 0 0 -5//72 7//72 -17//216 0]
83//1536 -13//384 283//1536 -167//1536 -49//512 0 0 0 0
1225//32768 -1547//32768 287//2048 -749//4096 861//16384 0 0 0 0
233//3456 -19//1152 0 0 0 -5//72 7//72 -17//216 0]
τ_star = 0.7156

TU = MIRKTableau(s, T.(c), T.(v), T.(b), T.(x))
Expand Down
6 changes: 4 additions & 2 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual,
return nothing
end

function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual,
function __mirk_loss!(
resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual,
mesh, cache) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Expand Down Expand Up @@ -347,7 +348,8 @@ function __mirk_mpoint_jacobian!(J, _, x, bc_diffmode, nonbc_diffmode, bc_diffca
return nothing
end

function __mirk_mpoint_jacobian!(J::AlmostBandedMatrix, J_c, x, bc_diffmode, nonbc_diffmode,
function __mirk_mpoint_jacobian!(
J::AlmostBandedMatrix, J_c, x, bc_diffmode, nonbc_diffmode,
bc_diffcache, nonbc_diffcache, loss_bc::BC, loss_collocation::C, resid_bc,
resid_collocation, L::Int) where {BC, C}
J_bc = fillpart(J)
Expand Down
Loading

0 comments on commit e2ea47b

Please sign in to comment.