diff --git a/Project.toml b/Project.toml index c53c17b..c27d749 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "1.10.0" +version = "1.10.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index eba6d99..66993dd 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -1,32 +1,31 @@ module SimpleNonlinearSolve -using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations - -@recompile_invalidations begin - using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, - AutoPolyesterForwardDiff - using ArrayInterface: ArrayInterface - using ConcreteStructs: @concrete - using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode, - AbstractSafeNonlinearTerminationMode, - AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode, - NONLINEARSOLVE_DEFAULT_NORM - using DifferentiationInterface: DifferentiationInterface - using DiffResults: DiffResults - using FastClosures: @closure - using FiniteDiff: FiniteDiff - using ForwardDiff: ForwardDiff, Dual - using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu, - mul!, norm, transpose - using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex - using Reexport: @reexport - using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, - NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, - ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, - build_solution, isinplace, _unwrap_val - using Setfield: @set! - using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size -end +using PrecompileTools: @compile_workload, @setup_workload + +using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, + AutoPolyesterForwardDiff +using ArrayInterface: ArrayInterface +using ConcreteStructs: @concrete +using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode, + AbstractSafeNonlinearTerminationMode, + AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode, + NONLINEARSOLVE_DEFAULT_NORM +using DifferentiationInterface: DifferentiationInterface +using DiffResults: DiffResults +using FastClosures: @closure +using FiniteDiff: FiniteDiff +using ForwardDiff: ForwardDiff, Dual +using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu, mul!, + norm, transpose +using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex +using Reexport: @reexport +using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + AbstractNonlinearFunction, StandardNonlinearProblem, + NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, + ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, + build_solution, isinplace, _unwrap_val, warn_paramtype +using Setfield: @set! +using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size const DI = DifferentiationInterface @@ -37,7 +36,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end @inline __is_extension_loaded(::Val) = false - +include("immutable_nonlinear_problem.jl") include("utils.jl") include("linesearch.jl") @@ -69,9 +68,21 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...; return solve(prob, ITP(), args...; prob.kwargs..., kwargs...) end -# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms +# Bypass the highlevel checks for NonlinearProblem for Simple Algorithms function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) + prob = convert(ImmutableNonlinearProblem, prob) + if sensealg === nothing && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + new_u0 = u0 !== nothing ? u0 : prob.u0 + new_p = p !== nothing ? p : prob.p + return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, + p === nothing, alg, args...; prob.kwargs..., kwargs...) +end + +function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end @@ -81,7 +92,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol p === nothing, alg, args...; prob.kwargs..., kwargs...) end -function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, +function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob return SciMLBase.__solve(prob, alg, args...; kwargs...) diff --git a/src/immutable_nonlinear_problem.jl b/src/immutable_nonlinear_problem.jl new file mode 100644 index 0000000..856a8df --- /dev/null +++ b/src/immutable_nonlinear_problem.jl @@ -0,0 +1,63 @@ +struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <: + AbstractNonlinearProblem{uType, isinplace} + f::F + u0::uType + p::P + problem_type::PT + kwargs::K + @add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0, + p = NullParameters(), + problem_type = StandardNonlinearProblem(); + kwargs...) where {iip} + if haskey(kwargs, :p) + error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.") + end + warn_paramtype(p) + new{typeof(u0), iip, typeof(p), typeof(f), + typeof(kwargs), typeof(problem_type)}(f, + u0, + p, + problem_type, + kwargs) + end + + """ + + Define a steady state problem using the given function. + `isinplace` optionally sets whether the function is inplace or not. + This is determined automatically, but not inferred. + """ + function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip} + ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...) + end +end + +""" + +Define a nonlinear problem using an instance of +[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction). +""" +function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...) + ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...) +end + +function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...) + ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...) +end + +""" + +Define a ImmutableNonlinearProblem problem from SteadyStateProblem +""" +function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem) + ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p) +end + + +function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem} + ImmutableNonlinearProblem{isinplace(prob)}(prob.f, + prob.u0, + prob.p, + prob.problem_type; + prob.kwargs...) +end diff --git a/src/nlsolve/broyden.jl b/src/nlsolve/broyden.jl index 890578f..2b02955 100644 --- a/src/nlsolve/broyden.jl +++ b/src/nlsolve/broyden.jl @@ -22,7 +22,7 @@ end __get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS) -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/dfsane.jl b/src/nlsolve/dfsane.jl index 835ee4b..7460666 100644 --- a/src/nlsolve/dfsane.jl +++ b/src/nlsolve/dfsane.jl @@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) where {M} x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/halley.jl b/src/nlsolve/halley.jl index d3fe1cd..bf2f5f8 100644 --- a/src/nlsolve/halley.jl +++ b/src/nlsolve/halley.jl @@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method. autodiff = nothing end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/klement.jl b/src/nlsolve/klement.jl index 8041ef4..2db3405 100644 --- a/src/nlsolve/klement.jl +++ b/src/nlsolve/klement.jl @@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems. """ struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/lbroyden.jl b/src/nlsolve/lbroyden.jl index 1eeda4f..683e9e5 100644 --- a/src/nlsolve/lbroyden.jl +++ b/src/nlsolve/lbroyden.jl @@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(; return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden, +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden, args...; termination_condition = nothing, kwargs...) if prob.u0 isa SArray if termination_condition === nothing || @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd return __generic_solve(prob, alg, args...; termination_condition, kwargs...) end -@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden, +@views function __generic_solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) @@ -114,7 +114,7 @@ end # Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite # finicky, so we'll implement it separately from the generic version # Ignore termination_condition. Don't pass things into internal functions -function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden, +function __static_solve(prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing, maxiters = 1000, kwargs...) x = prob.u0 fx = _get_fx(prob, x) diff --git a/src/nlsolve/raphson.jl b/src/nlsolve/raphson.jl index ca6864b..d7e72db 100644 --- a/src/nlsolve/raphson.jl +++ b/src/nlsolve/raphson.jl @@ -23,7 +23,7 @@ end const SimpleGaussNewton = SimpleNewtonRaphson -function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, +function SciMLBase.__solve(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing, maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/trustRegion.jl b/src/nlsolve/trustRegion.jl index 6ff263e..2e78baf 100644 --- a/src/nlsolve/trustRegion.jl +++ b/src/nlsolve/trustRegion.jl @@ -55,7 +55,7 @@ scalar and static array problems. nlsolve_update_rule = Val(false) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/utils.jl b/src/utils.jl index 2e76a4b..d8343a9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -123,7 +123,7 @@ end error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`") return _get_fx(prob.f, x, prob.p) end -@inline _get_fx(prob::NonlinearProblem, x) = _get_fx(prob.f, x, prob.p) +@inline _get_fx(prob::ImmutableNonlinearProblem, x) = _get_fx(prob.f, x, prob.p) @inline function _get_fx(f::NonlinearFunction, x, p) if isinplace(f) if f.resid_prototype !== nothing @@ -145,7 +145,7 @@ end # different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve # is meant for low overhead solvers, users can opt into the other termination modes but the # default is to use the least overhead version. -function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing) +function init_termination_cache(prob::ImmutableNonlinearProblem, abstol, reltol, du, u, ::Nothing) return init_termination_cache( prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs))) end @@ -155,7 +155,7 @@ function init_termination_cache( prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2))) end -function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, +function init_termination_cache(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) T = promote_type(eltype(du), eltype(u)) abstol = __get_tolerance(u, abstol, T)