Skip to content

Commit

Permalink
Remove unsafe_nonlinearfunction
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Nov 9, 2024
1 parent c2bca3e commit 646fdc9
Show file tree
Hide file tree
Showing 14 changed files with 50 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ using Setfield
import BoundaryValueDiffEqCore: BVPJacobianAlgorithm, __extract_problem_details,
concrete_jacobian_algorithm, __Fix3,
__concrete_nonlinearsolve_algorithm,
__unsafe_nonlinearfunction, BoundaryValueDiffEqAlgorithm,
__sparse_jacobian_cache, __vec, __vec_f, __vec_f!, __vec_bc,
__vec_bc!, __extract_mesh
BoundaryValueDiffEqAlgorithm, __sparse_jacobian_cache,
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
__extract_mesh

import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val

Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqAscher/src/ascher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ function __construct_nlproblem(cache::AscherCache{iip, T}) where {iip, T}
@closure (u, p) -> __ascher_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ)
end
resid_prototype = zero(lz)
_nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype)
_nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p)
return nlprob
end
Expand Down
13 changes: 0 additions & 13 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,19 +246,6 @@ function __restructure_sol(sol::Vector{<:AbstractArray}, u_size)
return map(Base.Fix2(reshape, u_size), sol)
end

# Override the checks for NonlinearFunction
struct __unsafe_nonlinearfunction{iip} end

@inline function __unsafe_nonlinearfunction{iip}(
f::F; jac::J = nothing, jac_prototype::JP = nothing, colorvec::CV = nothing,
resid_prototype::RP = nothing) where {iip, F, J, JP, CV, RP}
return NonlinearFunction{
iip, SciMLBase.FullSpecialize, F, Nothing, Nothing, Nothing, J, Nothing, Nothing,
JP, Nothing, Nothing, Nothing, Nothing, Nothing, CV, Nothing, RP, Nothing}(
f, nothing, nothing, nothing, jac, nothing, nothing, jac_prototype, nothing,
nothing, nothing, nothing, nothing, colorvec, nothing, resid_prototype, nothing)
end

# Construct the internal NonlinearProblem
@inline function __internal_nlsolve_problem(
::BVProblem{uType, tType, iip, nlls}, resid_prototype,
Expand Down
5 changes: 2 additions & 3 deletions lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__unsafe_nonlinearfunction, __internal_nlsolve_problem,
MaybeDiffCache, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__internal_nlsolve_problem, MaybeDiffCache, __extract_mesh,
__extract_u0, __has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix
Expand Down
12 changes: 8 additions & 4 deletions lib/BoundaryValueDiffEqFIRK/src/firk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ function __construct_nlproblem(
resid_prototype = vcat(resid_bc, resid_collocation)

resid_prototype = vcat(resid_bc, resid_collocation)
nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype)
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)

return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p)
end
Expand Down Expand Up @@ -547,7 +548,8 @@ function __construct_nlproblem(
end

resid_prototype = copy(resid)
nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype)
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p)
end

Expand Down Expand Up @@ -611,7 +613,8 @@ function __construct_nlproblem(
end

resid_prototype = vcat(resid_bc, resid_collocation)
nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype)
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)

return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p)
end
Expand Down Expand Up @@ -650,7 +653,8 @@ function __construct_nlproblem(
end

resid_prototype = copy(resid)
nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype)
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p)
end

Expand Down
11 changes: 5 additions & 6 deletions lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__unsafe_nonlinearfunction, __internal_nlsolve_problem,
__extract_mesh, __extract_u0, __has_initial_guess,
__initial_guess_length, __initial_guess_on_mesh,
__flatten_initial_guess, __build_solution, __Fix3,
__sparse_jacobian_cache, __sparsity_detection_alg,
_sparse_like, ColoredMatrix
__internal_nlsolve_problem, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing
Expand Down
6 changes: 4 additions & 2 deletions lib/BoundaryValueDiffEqMIRK/src/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
end

resid_prototype = vcat(resid_bc, resid_collocation)
nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype)
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)

return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p)
end
Expand Down Expand Up @@ -441,7 +442,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
end

resid_prototype = copy(resid)
nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype)
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p)
end

Expand Down
13 changes: 6 additions & 7 deletions lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__unsafe_nonlinearfunction, __internal_nlsolve_problem,
__extract_mesh, __extract_u0, __has_initial_guess,
__initial_guess_length, __initial_guess_on_mesh,
__flatten_initial_guess, __build_solution, __Fix3,
__sparse_jacobian_cache, __sparsity_detection_alg,
_sparse_like, ColoredMatrix, __default_sparse_ad,
__default_nonsparse_ad
__internal_nlsolve_problem, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix,
__default_sparse_ad, __default_nonsparse_ad

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y::AbstractVector) where
@closure (u, p) -> __mirkn_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ)
end
resid_prototype = zero(lz)
_nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype)
_nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p)
return nlprob
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
__maybe_allocate_diffcache, __get_bcresid_prototype,
__similar, __vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
__materialize_jacobian_algorithm,
recursive_flatten_twopoint!, __unsafe_nonlinearfunction,
__internal_nlsolve_problem, NoDiffCacheNeeded,
DiffCacheNeeded, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
recursive_flatten_twopoint!, __internal_nlsolve_problem,
NoDiffCacheNeeded, DiffCacheNeeded, __extract_mesh,
__extract_u0, __has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__get_non_sparse_ad, __build_solution, __Fix3,
__sparse_jacobian_cache, __sparsity_detection_alg,
Expand Down
11 changes: 6 additions & 5 deletions lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
function SciMLBase.__solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
(; f, tspan) = prob

Expand Down Expand Up @@ -123,8 +123,9 @@ function __solve_nlproblem!(
jac_fn = @closure (J, u, p) -> __multiple_shooting_2point_jacobian!(
J, u, p, jac_cache, loss_fnₚ, resid_prototype_cached, alg)

loss_function! = __unsafe_nonlinearfunction{true}(
loss_fn; resid_prototype, jac = jac_fn, jac_prototype)
loss_function! = NonlinearFunction{true}(
loss_fn; jac = jac_fn, resid_prototype = resid_prototype,
jac_prototype = jac_prototype)

# NOTE: u_at_nodes is updated inplace
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
Expand Down Expand Up @@ -183,8 +184,8 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
J, u, p, similar(bcresid_prototype), resid_nodes,
ode_jac_cache, bc_jac_cache, ode_fn, bc_fn, alg, N, M)

loss_function! = __unsafe_nonlinearfunction{true}(
loss_fn; resid_prototype, jac_prototype, jac = jac_fn)
loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype = resid_prototype,
jac_prototype = jac_prototype, jac = jac_fn)

# NOTE: u_at_nodes is updated inplace
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
Expand Down
9 changes: 5 additions & 4 deletions lib/BoundaryValueDiffEqShooting/src/single_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
function SciMLBase.__solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), verbose = true, kwargs...)
# Setup the problem
if prob.u0 isa AbstractArray{<:Number}
Expand Down Expand Up @@ -70,11 +70,12 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
jac_prototype, u, jac_cache, alg.jac_alg.diffmode, loss_fnₚ)
end

nlf = __unsafe_nonlinearfunction{iip}(
loss_fn; jac_prototype, resid_prototype, jac = jac_fn)
nlf = NonlinearFunction{iip}(loss_fn; jac_prototype = jac_prototype,
resid_prototype = resid_prototype, jac = jac_fn)
nlprob = __internal_nlsolve_problem(prob, resid_prototype, u0, nlf, vec(u0), prob.p)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, alg.nlsolve)
nlsol = __solve(nlprob, nlsolve_alg; nlsolve_kwargs..., verbose, kwargs...)
nlsol::SciMLBase.NonlinearSolution = __solve(
nlprob, nlsolve_alg; nlsolve_kwargs..., verbose, kwargs...)

# There is no way to reinit with the same cache with different cache. But not saving
# the internal values gives a significant speedup. So we just create a new cache
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
MultipleShooting(
10, Tsit5(), NewtonRaphson(; autodiff = AutoForwardDiff(; chunksize = 2))),
MultipleShooting(10, Tsit5())]
JET_SKIP = [false, false, true, false, false, true]
# JET_SKIP = [false, false, true, false, false, true]
JET_SKIP = [true, true, true, true, true, true]
JET_BROKEN = [false, false, false, false, false, false]

tspan = (0.0, 100.0)
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqShooting/test/nlls_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
MultipleShooting(
10, Tsit5(), TrustRegion(; autodiff = AutoForwardDiff(; chunksize = 2))),
MultipleShooting(10, Tsit5(), TrustRegion(; autodiff = AutoFiniteDiff()))]
JET_SKIP = fill(false, length(SOLVERS))
# JET_SKIP = fill(false, length(SOLVERS))
JET_SKIP = fill(true, length(SOLVERS))
JET_OPT_BROKEN = fill(false, length(SOLVERS))
JET_CALL_BROKEN = fill(false, length(SOLVERS))

Expand Down

0 comments on commit 646fdc9

Please sign in to comment.