From 865a0f82d71250290efbec5d4a56a042cc744278 Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Fri, 7 Jun 2024 18:07:26 -0400 Subject: [PATCH 01/11] convert to immutable for CUDA tests --- src/SimpleNonlinearSolve.jl | 8 ++-- src/immutable_nonlinear_problem.jl | 68 ++++++++++++++++++++++++++++++ test/gpu/cuda_tests.jl | 1 + 3 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 src/immutable_nonlinear_problem.jl diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index eba6d99..fd8c0a5 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -20,12 +20,13 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati mul!, norm, transpose using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex using Reexport: @reexport - using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + AbstractNonlinearFunction, StandardNonlinearProblem, NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, - build_solution, isinplace, _unwrap_val + build_solution, isinplace, _unwrap_val, warn_paramtype using Setfield: @set! - using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size + using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size, SizedVector, SizedMatrix end const DI = DifferentiationInterface @@ -60,6 +61,7 @@ include("bracketing/itp.jl") # AD include("ad.jl") +include("immutable_nonlinear_problem.jl") ## Default algorithm diff --git a/src/immutable_nonlinear_problem.jl b/src/immutable_nonlinear_problem.jl new file mode 100644 index 0000000..0ac98ef --- /dev/null +++ b/src/immutable_nonlinear_problem.jl @@ -0,0 +1,68 @@ +struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <: + AbstractNonlinearProblem{uType, isinplace} + f::F + u0::uType + p::P + problem_type::PT + kwargs::K + @add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0, + p = NullParameters(), + problem_type = StandardNonlinearProblem(); + kwargs...) where {iip} + if haskey(kwargs, :p) + error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.") + end + warn_paramtype(p) + new{typeof(u0), iip, typeof(p), typeof(f), + typeof(kwargs), typeof(problem_type)}(f, + u0, + p, + problem_type, + kwargs) + end + + """ + + Define a steady state problem using the given function. + `isinplace` optionally sets whether the function is inplace or not. + This is determined automatically, but not inferred. + """ + function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip} + ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...) + end +end + +""" + +Define a nonlinear problem using an instance of +[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction). +""" +function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...) + ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...) +end + +function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...) + ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...) +end + +""" + +Define a ImmutableNonlinearProblem problem from SteadyStateProblem +""" +function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem) + ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p) +end + +staticarray_itize(x) = x +staticarray_itize(x::Vector) = SVector{length(x)}(x) +staticarray_itize(x::SizedVector) = SVector{length(x)}(x) +staticarray_itize(x::Matrix) = SMatrix{size(x)...}(x) +staticarray_itize(x::SizedMatrix) = SMatrix{size(x)...}(x) + +function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem} + ImmutableNonlinearProblem(prob.f, + staticarray_itize(prob.u0), + staticarray_itize(prob.p), + prob.problem_type; + prob.kwargs...) +end diff --git a/test/gpu/cuda_tests.jl b/test/gpu/cuda_tests.jl index efc0340..5050b70 100644 --- a/test/gpu/cuda_tests.jl +++ b/test/gpu/cuda_tests.jl @@ -51,6 +51,7 @@ end end prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0]) + prob = convert(SimpleNonlinearSolve.ImmutableNonlinearProblem, prob) @testset "$(nameof(typeof(alg)))" for alg in ( SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(), From a791d82d3a24bd04c9b70d6f695397b32de5e844 Mon Sep 17 00:00:00 2001 From: Matt Bossart <67015312+m-bossart@users.noreply.github.com> Date: Sat, 8 Jun 2024 10:34:39 -0400 Subject: [PATCH 02/11] fix dispatch Co-authored-by: Christopher Rackauckas --- src/immutable_nonlinear_problem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/immutable_nonlinear_problem.jl b/src/immutable_nonlinear_problem.jl index 0ac98ef..ae0c116 100644 --- a/src/immutable_nonlinear_problem.jl +++ b/src/immutable_nonlinear_problem.jl @@ -60,7 +60,7 @@ staticarray_itize(x::Matrix) = SMatrix{size(x)...}(x) staticarray_itize(x::SizedMatrix) = SMatrix{size(x)...}(x) function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem} - ImmutableNonlinearProblem(prob.f, + ImmutableNonlinearProblem{isinplace(prob)}(prob.f, staticarray_itize(prob.u0), staticarray_itize(prob.p), prob.problem_type; From 5afefcd70b1534fa7d9930f30d3ad8750139cfbd Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 8 Jun 2024 15:49:52 -0400 Subject: [PATCH 03/11] dispatch on Union{NonlinearProble, ImmutableNonlinearProblem} --- src/SimpleNonlinearSolve.jl | 18 ++++++++++++++---- src/nlsolve/broyden.jl | 2 +- src/nlsolve/dfsane.jl | 2 +- src/nlsolve/halley.jl | 2 +- src/nlsolve/klement.jl | 2 +- src/nlsolve/lbroyden.jl | 6 +++--- src/nlsolve/raphson.jl | 2 +- src/nlsolve/trustRegion.jl | 2 +- src/utils.jl | 6 +++--- 9 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index fd8c0a5..f7f9ea2 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -38,7 +38,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end @inline __is_extension_loaded(::Val) = false - +include("immutable_nonlinear_problem.jl") include("utils.jl") include("linesearch.jl") @@ -61,7 +61,6 @@ include("bracketing/itp.jl") # AD include("ad.jl") -include("immutable_nonlinear_problem.jl") ## Default algorithm @@ -72,7 +71,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...; end # By Pass the highlevel checks for NonlinearProblem for Simple Algorithms -function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, +function SciMLBase.solve(prob::Union{NonlinearProblem}, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] @@ -83,7 +82,18 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol p === nothing, alg, args...; prob.kwargs..., kwargs...) end -function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, +function SciMLBase.solve(prob::Union{ImmutableNonlinearProblem}, alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) +if sensealg === nothing && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] +end +new_u0 = u0 !== nothing ? u0 : prob.u0 +new_p = p !== nothing ? p : prob.p +return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, + p === nothing, alg, args...; prob.kwargs..., kwargs...) +end + +function __internal_solve_up(_prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob return SciMLBase.__solve(prob, alg, args...; kwargs...) diff --git a/src/nlsolve/broyden.jl b/src/nlsolve/broyden.jl index 890578f..cbae8a4 100644 --- a/src/nlsolve/broyden.jl +++ b/src/nlsolve/broyden.jl @@ -22,7 +22,7 @@ end __get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS) -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...; +function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/dfsane.jl b/src/nlsolve/dfsane.jl index 835ee4b..c25ed26 100644 --- a/src/nlsolve/dfsane.jl +++ b/src/nlsolve/dfsane.jl @@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...; +function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleDFSane{M}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) where {M} x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/halley.jl b/src/nlsolve/halley.jl index d3fe1cd..47d7614 100644 --- a/src/nlsolve/halley.jl +++ b/src/nlsolve/halley.jl @@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method. autodiff = nothing end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...; +function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleHalley, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/klement.jl b/src/nlsolve/klement.jl index 8041ef4..27d4f69 100644 --- a/src/nlsolve/klement.jl +++ b/src/nlsolve/klement.jl @@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems. """ struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...; +function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleKlement, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/lbroyden.jl b/src/nlsolve/lbroyden.jl index 1eeda4f..aefe122 100644 --- a/src/nlsolve/lbroyden.jl +++ b/src/nlsolve/lbroyden.jl @@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(; return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden, +function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleLimitedMemoryBroyden, args...; termination_condition = nothing, kwargs...) if prob.u0 isa SArray if termination_condition === nothing || @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd return __generic_solve(prob, alg, args...; termination_condition, kwargs...) end -@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden, +@views function __generic_solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) @@ -114,7 +114,7 @@ end # Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite # finicky, so we'll implement it separately from the generic version # Ignore termination_condition. Don't pass things into internal functions -function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden, +function __static_solve(prob::Union{NonlinearProblem{<:SArray}, ImmutableNonlinearProblem{<:SArray}}, alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing, maxiters = 1000, kwargs...) x = prob.u0 fx = _get_fx(prob, x) diff --git a/src/nlsolve/raphson.jl b/src/nlsolve/raphson.jl index ca6864b..f9bb999 100644 --- a/src/nlsolve/raphson.jl +++ b/src/nlsolve/raphson.jl @@ -23,7 +23,7 @@ end const SimpleGaussNewton = SimpleNewtonRaphson -function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, +function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing, maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/trustRegion.jl b/src/nlsolve/trustRegion.jl index 6ff263e..81e0377 100644 --- a/src/nlsolve/trustRegion.jl +++ b/src/nlsolve/trustRegion.jl @@ -55,7 +55,7 @@ scalar and static array problems. nlsolve_update_rule = Val(false) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...; +function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleTrustRegion, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/utils.jl b/src/utils.jl index 2e76a4b..7af0807 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -123,7 +123,7 @@ end error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`") return _get_fx(prob.f, x, prob.p) end -@inline _get_fx(prob::NonlinearProblem, x) = _get_fx(prob.f, x, prob.p) +@inline _get_fx(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, x) = _get_fx(prob.f, x, prob.p) @inline function _get_fx(f::NonlinearFunction, x, p) if isinplace(f) if f.resid_prototype !== nothing @@ -145,7 +145,7 @@ end # different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve # is meant for low overhead solvers, users can opt into the other termination modes but the # default is to use the least overhead version. -function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing) +function init_termination_cache(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, abstol, reltol, du, u, ::Nothing) return init_termination_cache( prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs))) end @@ -155,7 +155,7 @@ function init_termination_cache( prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2))) end -function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, +function init_termination_cache(prob::Union{NonlinearProblem, ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) T = promote_type(eltype(du), eltype(u)) abstol = __get_tolerance(u, abstol, T) From 783479a8fbf8fbf5b55e2f92df69aaf22b46b513 Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Mon, 10 Jun 2024 08:53:39 -0400 Subject: [PATCH 04/11] fix formatting --- src/SimpleNonlinearSolve.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index f7f9ea2..9dc50db 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -70,8 +70,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...; return solve(prob, ITP(), args...; prob.kwargs..., kwargs...) end -# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms -function SciMLBase.solve(prob::Union{NonlinearProblem}, alg::AbstractSimpleNonlinearSolveAlgorithm, +# Bypass the highlevel checks for NonlinearProblem for Simple Algorithms +function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] @@ -82,15 +82,15 @@ function SciMLBase.solve(prob::Union{NonlinearProblem}, alg::AbstractSimpleNonli p === nothing, alg, args...; prob.kwargs..., kwargs...) end -function SciMLBase.solve(prob::Union{ImmutableNonlinearProblem}, alg::AbstractSimpleNonlinearSolveAlgorithm, - args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) -if sensealg === nothing && haskey(prob.kwargs, :sensealg) - sensealg = prob.kwargs[:sensealg] -end -new_u0 = u0 !== nothing ? u0 : prob.u0 -new_p = p !== nothing ? p : prob.p -return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, - p === nothing, alg, args...; prob.kwargs..., kwargs...) +function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) + if sensealg === nothing && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + new_u0 = u0 !== nothing ? u0 : prob.u0 + new_p = p !== nothing ? p : prob.p + return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, + p === nothing, alg, args...; prob.kwargs..., kwargs...) end function __internal_solve_up(_prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, sensealg, u0, u0_changed, From 2bb415f241343638b982aeda31ff7b644a3d68ad Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Mon, 10 Jun 2024 09:43:55 -0400 Subject: [PATCH 05/11] convert in solve --- Project.toml | 3 +-- src/SimpleNonlinearSolve.jl | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c53c17b..34369cb 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -63,7 +63,6 @@ SciMLBase = "2.37.0" SciMLSensitivity = "7.58" Setfield = "1.1.1" StaticArrays = "1.9" -StaticArraysCore = "1.4.2" Test = "1.10" Tracker = "0.2.33" Zygote = "0.6.69" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 9dc50db..9558b3f 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -26,7 +26,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val, warn_paramtype using Setfield: @set! - using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size, SizedVector, SizedMatrix + using StaticArrays: StaticArray, SVector, SMatrix, SArray, MArray, Size, SizedVector, SizedMatrix end const DI = DifferentiationInterface @@ -73,6 +73,7 @@ end # Bypass the highlevel checks for NonlinearProblem for Simple Algorithms function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) + prob = convert(ImmutableNonlinearProblem, prob) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end From c4f5b3876071ab7720e55d4f7a368048ea006507 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 13 Jun 2024 05:53:38 -0400 Subject: [PATCH 06/11] Update test/gpu/cuda_tests.jl --- test/gpu/cuda_tests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/gpu/cuda_tests.jl b/test/gpu/cuda_tests.jl index 5050b70..efc0340 100644 --- a/test/gpu/cuda_tests.jl +++ b/test/gpu/cuda_tests.jl @@ -51,7 +51,6 @@ end end prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0]) - prob = convert(SimpleNonlinearSolve.ImmutableNonlinearProblem, prob) @testset "$(nameof(typeof(alg)))" for alg in ( SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(), From a0dc59b38c3bf2e034ed598f9867b9856b8770ed Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 29 Jun 2024 14:25:08 -0400 Subject: [PATCH 07/11] changes from SciML --- Project.toml | 2 +- src/SimpleNonlinearSolve.jl | 53 +++++++++++++++++-------------------- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/Project.toml b/Project.toml index 34369cb..5d88c24 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "1.10.0" +version = "1.10.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 9558b3f..be5201b 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -1,33 +1,30 @@ module SimpleNonlinearSolve -using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations - -@recompile_invalidations begin - using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, - AutoPolyesterForwardDiff - using ArrayInterface: ArrayInterface - using ConcreteStructs: @concrete - using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode, - AbstractSafeNonlinearTerminationMode, - AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode, - NONLINEARSOLVE_DEFAULT_NORM - using DifferentiationInterface: DifferentiationInterface - using DiffResults: DiffResults - using FastClosures: @closure - using FiniteDiff: FiniteDiff - using ForwardDiff: ForwardDiff, Dual - using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu, - mul!, norm, transpose - using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex - using Reexport: @reexport - using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, - AbstractNonlinearFunction, StandardNonlinearProblem, - NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, - ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, - build_solution, isinplace, _unwrap_val, warn_paramtype - using Setfield: @set! - using StaticArrays: StaticArray, SVector, SMatrix, SArray, MArray, Size, SizedVector, SizedMatrix -end +using PrecompileTools: @compile_workload, @setup_workload + +using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, + AutoPolyesterForwardDiff +using ArrayInterface: ArrayInterface +using ConcreteStructs: @concrete +using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode, + AbstractSafeNonlinearTerminationMode, + AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode, + NONLINEARSOLVE_DEFAULT_NORM +using DifferentiationInterface: DifferentiationInterface +using DiffResults: DiffResults +using FastClosures: @closure +using FiniteDiff: FiniteDiff +using ForwardDiff: ForwardDiff, Dual +using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu, mul!, + norm, transpose +using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex +using Reexport: @reexport +using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, + ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, + build_solution, isinplace, _unwrap_val +using Setfield: @set! +using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size const DI = DifferentiationInterface From 79cf8855ec90f70953ae7119065a9c51811c49c7 Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 29 Jun 2024 15:27:14 -0400 Subject: [PATCH 08/11] address comments --- Project.toml | 2 +- src/SimpleNonlinearSolve.jl | 7 ++++--- src/immutable_nonlinear_problem.jl | 9 ++------- src/nlsolve/broyden.jl | 2 +- src/nlsolve/dfsane.jl | 2 +- src/nlsolve/halley.jl | 2 +- src/nlsolve/klement.jl | 2 +- src/nlsolve/lbroyden.jl | 6 +++--- src/nlsolve/trustRegion.jl | 2 +- src/utils.jl | 4 ++-- 10 files changed, 17 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index 5d88c24..c6acdfe 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index be5201b..66993dd 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -19,10 +19,11 @@ using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess norm, transpose using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex using Reexport: @reexport -using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, +using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + AbstractNonlinearFunction, StandardNonlinearProblem, NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, - build_solution, isinplace, _unwrap_val + build_solution, isinplace, _unwrap_val, warn_paramtype using Setfield: @set! using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size @@ -91,7 +92,7 @@ function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNon p === nothing, alg, args...; prob.kwargs..., kwargs...) end -function __internal_solve_up(_prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, sensealg, u0, u0_changed, +function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob return SciMLBase.__solve(prob, alg, args...; kwargs...) diff --git a/src/immutable_nonlinear_problem.jl b/src/immutable_nonlinear_problem.jl index ae0c116..856a8df 100644 --- a/src/immutable_nonlinear_problem.jl +++ b/src/immutable_nonlinear_problem.jl @@ -53,16 +53,11 @@ function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem) ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p) end -staticarray_itize(x) = x -staticarray_itize(x::Vector) = SVector{length(x)}(x) -staticarray_itize(x::SizedVector) = SVector{length(x)}(x) -staticarray_itize(x::Matrix) = SMatrix{size(x)...}(x) -staticarray_itize(x::SizedMatrix) = SMatrix{size(x)...}(x) function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem} ImmutableNonlinearProblem{isinplace(prob)}(prob.f, - staticarray_itize(prob.u0), - staticarray_itize(prob.p), + prob.u0, + prob.p, prob.problem_type; prob.kwargs...) end diff --git a/src/nlsolve/broyden.jl b/src/nlsolve/broyden.jl index cbae8a4..2b02955 100644 --- a/src/nlsolve/broyden.jl +++ b/src/nlsolve/broyden.jl @@ -22,7 +22,7 @@ end __get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS) -function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleBroyden, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/dfsane.jl b/src/nlsolve/dfsane.jl index c25ed26..7460666 100644 --- a/src/nlsolve/dfsane.jl +++ b/src/nlsolve/dfsane.jl @@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy) end -function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleDFSane{M}, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) where {M} x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/halley.jl b/src/nlsolve/halley.jl index 47d7614..bf2f5f8 100644 --- a/src/nlsolve/halley.jl +++ b/src/nlsolve/halley.jl @@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method. autodiff = nothing end -function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleHalley, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/klement.jl b/src/nlsolve/klement.jl index 27d4f69..2db3405 100644 --- a/src/nlsolve/klement.jl +++ b/src/nlsolve/klement.jl @@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems. """ struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end -function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleKlement, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/nlsolve/lbroyden.jl b/src/nlsolve/lbroyden.jl index aefe122..683e9e5 100644 --- a/src/nlsolve/lbroyden.jl +++ b/src/nlsolve/lbroyden.jl @@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(; return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha) end -function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleLimitedMemoryBroyden, +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden, args...; termination_condition = nothing, kwargs...) if prob.u0 isa SArray if termination_condition === nothing || @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProbl return __generic_solve(prob, alg, args...; termination_condition, kwargs...) end -@views function __generic_solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleLimitedMemoryBroyden, +@views function __generic_solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) @@ -114,7 +114,7 @@ end # Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite # finicky, so we'll implement it separately from the generic version # Ignore termination_condition. Don't pass things into internal functions -function __static_solve(prob::Union{NonlinearProblem{<:SArray}, ImmutableNonlinearProblem{<:SArray}}, alg::SimpleLimitedMemoryBroyden, +function __static_solve(prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing, maxiters = 1000, kwargs...) x = prob.u0 fx = _get_fx(prob, x) diff --git a/src/nlsolve/trustRegion.jl b/src/nlsolve/trustRegion.jl index 81e0377..2e78baf 100644 --- a/src/nlsolve/trustRegion.jl +++ b/src/nlsolve/trustRegion.jl @@ -55,7 +55,7 @@ scalar and static array problems. nlsolve_update_rule = Val(false) end -function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleTrustRegion, args...; +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) diff --git a/src/utils.jl b/src/utils.jl index 7af0807..851af59 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -123,7 +123,7 @@ end error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`") return _get_fx(prob.f, x, prob.p) end -@inline _get_fx(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, x) = _get_fx(prob.f, x, prob.p) +@inline _get_fx(prob::ImmutableNonlinearProblem, x) = _get_fx(prob.f, x, prob.p) @inline function _get_fx(f::NonlinearFunction, x, p) if isinplace(f) if f.resid_prototype !== nothing @@ -145,7 +145,7 @@ end # different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve # is meant for low overhead solvers, users can opt into the other termination modes but the # default is to use the least overhead version. -function init_termination_cache(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, abstol, reltol, du, u, ::Nothing) +function init_termination_cache(prob::ImmutableNonlinearProblem, abstol, reltol, du, u, ::Nothing) return init_termination_cache( prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs))) end From e8196a276716574bbe2c2a186629557822243499 Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 29 Jun 2024 15:29:36 -0400 Subject: [PATCH 09/11] another union --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 851af59..d8343a9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -155,7 +155,7 @@ function init_termination_cache( prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2))) end -function init_termination_cache(prob::Union{NonlinearProblem, ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, +function init_termination_cache(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) T = promote_type(eltype(du), eltype(u)) abstol = __get_tolerance(u, abstol, T) From dc37c148202541ebd171e5af9e5aa182529af844 Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 29 Jun 2024 15:30:58 -0400 Subject: [PATCH 10/11] another union --- src/nlsolve/raphson.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nlsolve/raphson.jl b/src/nlsolve/raphson.jl index f9bb999..d7e72db 100644 --- a/src/nlsolve/raphson.jl +++ b/src/nlsolve/raphson.jl @@ -23,7 +23,7 @@ end const SimpleGaussNewton = SimpleNewtonRaphson -function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, +function SciMLBase.__solve(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing, maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) From d7a1dec98fc5753e802ea0e5181af9a4b2d344ea Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 29 Jun 2024 15:44:34 -0400 Subject: [PATCH 11/11] StaticArraysCore compat --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index c6acdfe..c27d749 100644 --- a/Project.toml +++ b/Project.toml @@ -63,6 +63,7 @@ SciMLBase = "2.37.0" SciMLSensitivity = "7.58" Setfield = "1.1.1" StaticArrays = "1.9" +StaticArraysCore = "1.4.2" Test = "1.10" Tracker = "0.2.33" Zygote = "0.6.69"