diff --git a/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl b/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl index 870f39e6..4284450f 100644 --- a/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl +++ b/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl @@ -16,9 +16,9 @@ using Setfield import BoundaryValueDiffEqCore: BVPJacobianAlgorithm, __extract_problem_details, concrete_jacobian_algorithm, __Fix3, __concrete_nonlinearsolve_algorithm, - __unsafe_nonlinearfunction, BoundaryValueDiffEqAlgorithm, - __sparse_jacobian_cache, __vec, __vec_f, __vec_f!, __vec_bc, - __vec_bc!, __extract_mesh + BoundaryValueDiffEqAlgorithm, __sparse_jacobian_cache, + __vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!, + __extract_mesh import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val diff --git a/lib/BoundaryValueDiffEqAscher/src/ascher.jl b/lib/BoundaryValueDiffEqAscher/src/ascher.jl index 08d40181..e0afe6b1 100644 --- a/lib/BoundaryValueDiffEqAscher/src/ascher.jl +++ b/lib/BoundaryValueDiffEqAscher/src/ascher.jl @@ -327,7 +327,8 @@ function __construct_nlproblem(cache::AscherCache{iip, T}) where {iip, T} @closure (u, p) -> __ascher_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ) end resid_prototype = zero(lz) - _nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + _nlf = NonlinearFunction{iip}( + loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p) return nlprob end diff --git a/lib/BoundaryValueDiffEqCore/src/utils.jl b/lib/BoundaryValueDiffEqCore/src/utils.jl index 8f3029c7..9e8dc781 100644 --- a/lib/BoundaryValueDiffEqCore/src/utils.jl +++ b/lib/BoundaryValueDiffEqCore/src/utils.jl @@ -246,19 +246,6 @@ function __restructure_sol(sol::Vector{<:AbstractArray}, u_size) return map(Base.Fix2(reshape, u_size), sol) end -# Override the checks for NonlinearFunction -struct __unsafe_nonlinearfunction{iip} end - -@inline function __unsafe_nonlinearfunction{iip}( - f::F; jac::J = nothing, jac_prototype::JP = nothing, colorvec::CV = nothing, - resid_prototype::RP = nothing) where {iip, F, J, JP, CV, RP} - return NonlinearFunction{ - iip, SciMLBase.FullSpecialize, F, Nothing, Nothing, Nothing, J, Nothing, Nothing, - JP, Nothing, Nothing, Nothing, Nothing, Nothing, CV, Nothing, RP, Nothing}( - f, nothing, nothing, nothing, jac, nothing, nothing, jac_prototype, nothing, - nothing, nothing, nothing, nothing, colorvec, nothing, resid_prototype, nothing) -end - # Construct the internal NonlinearProblem @inline function __internal_nlsolve_problem( ::BVProblem{uType, tType, iip, nlls}, resid_prototype, diff --git a/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl b/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl index 640580c5..070ffd2b 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl @@ -22,9 +22,8 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit __initial_guess, __maybe_allocate_diffcache, __get_bcresid_prototype, __similar, __vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!, - __unsafe_nonlinearfunction, __internal_nlsolve_problem, - MaybeDiffCache, __extract_mesh, __extract_u0, - __has_initial_guess, __initial_guess_length, + __internal_nlsolve_problem, MaybeDiffCache, __extract_mesh, + __extract_u0, __has_initial_guess, __initial_guess_length, __initial_guess_on_mesh, __flatten_initial_guess, __build_solution, __Fix3, __sparse_jacobian_cache, __sparsity_detection_alg, _sparse_like, ColoredMatrix diff --git a/lib/BoundaryValueDiffEqFIRK/src/firk.jl b/lib/BoundaryValueDiffEqFIRK/src/firk.jl index ca43e5f4..4cc2649b 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/firk.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/firk.jl @@ -498,7 +498,8 @@ function __construct_nlproblem( resid_prototype = vcat(resid_bc, resid_collocation) resid_prototype = vcat(resid_bc, resid_collocation) - nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + nlf = NonlinearFunction{iip}( + loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) end @@ -547,7 +548,8 @@ function __construct_nlproblem( end resid_prototype = copy(resid) - nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + nlf = NonlinearFunction{iip}( + loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) end @@ -611,7 +613,8 @@ function __construct_nlproblem( end resid_prototype = vcat(resid_bc, resid_collocation) - nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + nlf = NonlinearFunction{iip}( + loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) end @@ -650,7 +653,8 @@ function __construct_nlproblem( end resid_prototype = copy(resid) - nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + nlf = NonlinearFunction{iip}( + loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) end diff --git a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl index 17a16c37..91e1a567 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl @@ -22,12 +22,11 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit __initial_guess, __maybe_allocate_diffcache, __get_bcresid_prototype, __similar, __vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!, - __unsafe_nonlinearfunction, __internal_nlsolve_problem, - __extract_mesh, __extract_u0, __has_initial_guess, - __initial_guess_length, __initial_guess_on_mesh, - __flatten_initial_guess, __build_solution, __Fix3, - __sparse_jacobian_cache, __sparsity_detection_alg, - _sparse_like, ColoredMatrix + __internal_nlsolve_problem, __extract_mesh, __extract_u0, + __has_initial_guess, __initial_guess_length, + __initial_guess_on_mesh, __flatten_initial_guess, + __build_solution, __Fix3, __sparse_jacobian_cache, + __sparsity_detection_alg, _sparse_like, ColoredMatrix import ADTypes: AbstractADType import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing diff --git a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl index 30ca870e..801b5f26 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl @@ -362,7 +362,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo end resid_prototype = vcat(resid_bc, resid_collocation) - nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + nlf = NonlinearFunction{iip}( + loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) end @@ -441,7 +442,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo end resid_prototype = copy(resid) - nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + nlf = NonlinearFunction{iip}( + loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) end diff --git a/lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl b/lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl index a359a23f..c3a7b92e 100644 --- a/lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl +++ b/lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl @@ -21,13 +21,12 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit __initial_guess, __maybe_allocate_diffcache, __get_bcresid_prototype, __similar, __vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!, - __unsafe_nonlinearfunction, __internal_nlsolve_problem, - __extract_mesh, __extract_u0, __has_initial_guess, - __initial_guess_length, __initial_guess_on_mesh, - __flatten_initial_guess, __build_solution, __Fix3, - __sparse_jacobian_cache, __sparsity_detection_alg, - _sparse_like, ColoredMatrix, __default_sparse_ad, - __default_nonsparse_ad + __internal_nlsolve_problem, __extract_mesh, __extract_u0, + __has_initial_guess, __initial_guess_length, + __initial_guess_on_mesh, __flatten_initial_guess, + __build_solution, __Fix3, __sparse_jacobian_cache, + __sparsity_detection_alg, _sparse_like, ColoredMatrix, + __default_sparse_ad, __default_nonsparse_ad import ADTypes: AbstractADType import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing diff --git a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl index 0b6630af..5b36e55a 100644 --- a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl +++ b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl @@ -124,7 +124,8 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y::AbstractVector) where @closure (u, p) -> __mirkn_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ) end resid_prototype = zero(lz) - _nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + _nlf = NonlinearFunction{iip}( + loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p) return nlprob end diff --git a/lib/BoundaryValueDiffEqShooting/src/BoundaryValueDiffEqShooting.jl b/lib/BoundaryValueDiffEqShooting/src/BoundaryValueDiffEqShooting.jl index 7fce94d4..e9f8f3b0 100644 --- a/lib/BoundaryValueDiffEqShooting/src/BoundaryValueDiffEqShooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/BoundaryValueDiffEqShooting.jl @@ -23,10 +23,9 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit __maybe_allocate_diffcache, __get_bcresid_prototype, __similar, __vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!, __materialize_jacobian_algorithm, - recursive_flatten_twopoint!, __unsafe_nonlinearfunction, - __internal_nlsolve_problem, NoDiffCacheNeeded, - DiffCacheNeeded, __extract_mesh, __extract_u0, - __has_initial_guess, __initial_guess_length, + recursive_flatten_twopoint!, __internal_nlsolve_problem, + NoDiffCacheNeeded, DiffCacheNeeded, __extract_mesh, + __extract_u0, __has_initial_guess, __initial_guess_length, __initial_guess_on_mesh, __flatten_initial_guess, __get_non_sparse_ad, __build_solution, __Fix3, __sparse_jacobian_cache, __sparsity_detection_alg, diff --git a/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl b/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl index 01117cfc..7c69875e 100644 --- a/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl @@ -1,4 +1,4 @@ -function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), +function SciMLBase.__solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...) (; f, tspan) = prob @@ -123,8 +123,9 @@ function __solve_nlproblem!( jac_fn = @closure (J, u, p) -> __multiple_shooting_2point_jacobian!( J, u, p, jac_cache, loss_fnₚ, resid_prototype_cached, alg) - loss_function! = __unsafe_nonlinearfunction{true}( - loss_fn; resid_prototype, jac = jac_fn, jac_prototype) + loss_function! = NonlinearFunction{true}( + loss_fn; jac = jac_fn, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) # NOTE: u_at_nodes is updated inplace nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p) @@ -183,8 +184,8 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_ J, u, p, similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache, ode_fn, bc_fn, alg, N, M) - loss_function! = __unsafe_nonlinearfunction{true}( - loss_fn; resid_prototype, jac_prototype, jac = jac_fn) + loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype = resid_prototype, + jac_prototype = jac_prototype, jac = jac_fn) # NOTE: u_at_nodes is updated inplace nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p) diff --git a/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl b/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl index 2e08fcba..231c8633 100644 --- a/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl @@ -1,4 +1,4 @@ -function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;), +function SciMLBase.__solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;), nlsolve_kwargs = (;), verbose = true, kwargs...) # Setup the problem if prob.u0 isa AbstractArray{<:Number} @@ -70,11 +70,12 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;), jac_prototype, u, jac_cache, alg.jac_alg.diffmode, loss_fnₚ) end - nlf = __unsafe_nonlinearfunction{iip}( - loss_fn; jac_prototype, resid_prototype, jac = jac_fn) + nlf = NonlinearFunction{iip}(loss_fn; jac_prototype = jac_prototype, + resid_prototype = resid_prototype, jac = jac_fn) nlprob = __internal_nlsolve_problem(prob, resid_prototype, u0, nlf, vec(u0), prob.p) nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, alg.nlsolve) - nlsol = __solve(nlprob, nlsolve_alg; nlsolve_kwargs..., verbose, kwargs...) + nlsol::SciMLBase.NonlinearSolution = __solve( + nlprob, nlsolve_alg; nlsolve_kwargs..., verbose, kwargs...) # There is no way to reinit with the same cache with different cache. But not saving # the internal values gives a significant speedup. So we just create a new cache diff --git a/lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl b/lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl index f2b1dd5c..2398dedb 100644 --- a/lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl +++ b/lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl @@ -8,7 +8,8 @@ MultipleShooting( 10, Tsit5(), NewtonRaphson(; autodiff = AutoForwardDiff(; chunksize = 2))), MultipleShooting(10, Tsit5())] - JET_SKIP = [false, false, true, false, false, true] + # JET_SKIP = [false, false, true, false, false, true] + JET_SKIP = [true, true, true, true, true, true] JET_BROKEN = [false, false, false, false, false, false] tspan = (0.0, 100.0) diff --git a/lib/BoundaryValueDiffEqShooting/test/nlls_tests.jl b/lib/BoundaryValueDiffEqShooting/test/nlls_tests.jl index ce606f6e..13ba4dcf 100644 --- a/lib/BoundaryValueDiffEqShooting/test/nlls_tests.jl +++ b/lib/BoundaryValueDiffEqShooting/test/nlls_tests.jl @@ -22,7 +22,8 @@ MultipleShooting( 10, Tsit5(), TrustRegion(; autodiff = AutoForwardDiff(; chunksize = 2))), MultipleShooting(10, Tsit5(), TrustRegion(; autodiff = AutoFiniteDiff()))] - JET_SKIP = fill(false, length(SOLVERS)) + # JET_SKIP = fill(false, length(SOLVERS)) + JET_SKIP = fill(true, length(SOLVERS)) JET_OPT_BROKEN = fill(false, length(SOLVERS)) JET_CALL_BROKEN = fill(false, length(SOLVERS))