Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

convert to immutable for CUDA tests #151

Closed
wants to merge 12 commits into from
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.10.0"
version = "1.10.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
71 changes: 41 additions & 30 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
module SimpleNonlinearSolve

using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations

@recompile_invalidations begin
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode,
NONLINEARSOLVE_DEFAULT_NORM
using DifferentiationInterface: DifferentiationInterface
using DiffResults: DiffResults
using FastClosures: @closure
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu,
mul!, norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
end
using PrecompileTools: @compile_workload, @setup_workload

using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode,
NONLINEARSOLVE_DEFAULT_NORM
using DifferentiationInterface: DifferentiationInterface
using DiffResults: DiffResults
using FastClosures: @closure
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu, mul!,
norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
AbstractNonlinearFunction, StandardNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val, warn_paramtype
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size

const DI = DifferentiationInterface

Expand All @@ -37,7 +36,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end

@inline __is_extension_loaded(::Val) = false

include("immutable_nonlinear_problem.jl")
include("utils.jl")
include("linesearch.jl")

Expand Down Expand Up @@ -69,9 +68,21 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...;
return solve(prob, ITP(), args...; prob.kwargs..., kwargs...)
end

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
# Bypass the highlevel checks for NonlinearProblem for Simple Algorithms
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
prob = convert(ImmutableNonlinearProblem, prob)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
new_u0 = u0 !== nothing ? u0 : prob.u0
new_p = p !== nothing ? p : prob.p
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
Expand All @@ -81,7 +92,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
return SciMLBase.__solve(prob, alg, args...; kwargs...)
Expand Down
63 changes: 63 additions & 0 deletions src/immutable_nonlinear_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
AbstractNonlinearProblem{uType, isinplace}
f::F
u0::uType
p::P
problem_type::PT
kwargs::K
@add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
p = NullParameters(),
problem_type = StandardNonlinearProblem();
kwargs...) where {iip}
if haskey(kwargs, :p)
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
end
warn_paramtype(p)
new{typeof(u0), iip, typeof(p), typeof(f),
typeof(kwargs), typeof(problem_type)}(f,
u0,
p,
problem_type,
kwargs)
end

"""

Define a steady state problem using the given function.
`isinplace` optionally sets whether the function is inplace or not.
This is determined automatically, but not inferred.
"""
function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
end
end

"""

Define a nonlinear problem using an instance of
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
"""
function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
end

function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
end

"""

Define a ImmutableNonlinearProblem problem from SteadyStateProblem
"""
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
end


function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
ImmutableNonlinearProblem{isinplace(prob)}(prob.f,
prob.u0,
prob.p,
prob.problem_type;
prob.kwargs...)
end
2 changes: 1 addition & 1 deletion src/nlsolve/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end

__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
termination_condition = nothing, kwargs...) where {M}
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method.
autodiff = nothing
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems.
"""
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
6 changes: 3 additions & 3 deletions src/nlsolve/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(;
return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
args...; termination_condition = nothing, kwargs...)
if prob.u0 isa SArray
if termination_condition === nothing ||
Expand All @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd
return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
end

@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
@views function __generic_solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down Expand Up @@ -114,7 +114,7 @@ end
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
# finicky, so we'll implement it separately from the generic version
# Ignore termination_condition. Don't pass things into internal functions
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
function __static_solve(prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, maxiters = 1000, kwargs...)
x = prob.u0
fx = _get_fx(prob, x)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end

const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function SciMLBase.__solve(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing,
maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ scalar and static array problems.
nlsolve_update_rule = Val(false)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ end
error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`")
return _get_fx(prob.f, x, prob.p)
end
@inline _get_fx(prob::NonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline _get_fx(prob::ImmutableNonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline function _get_fx(f::NonlinearFunction, x, p)
if isinplace(f)
if f.resid_prototype !== nothing
Expand All @@ -145,7 +145,7 @@ end
# different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve
# is meant for low overhead solvers, users can opt into the other termination modes but the
# default is to use the least overhead version.
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
function init_termination_cache(prob::ImmutableNonlinearProblem, abstol, reltol, du, u, ::Nothing)
return init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs)))
end
Expand All @@ -155,7 +155,7 @@ function init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2)))
end

function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function init_termination_cache(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
T = promote_type(eltype(du), eltype(u))
abstol = __get_tolerance(u, abstol, T)
Expand Down