From 9ea11473f977da1a3d4307041239058c0f8f6ad6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 28 Oct 2024 23:57:12 -0400 Subject: [PATCH 1/2] feat: use specialization in NonlinearProblems --- Project.toml | 2 +- src/norecompile.jl | 19 ++++++++++++ src/solve.jl | 73 +++++++++++++++++++++++++++++++--------------- 3 files changed, 69 insertions(+), 25 deletions(-) diff --git a/Project.toml b/Project.toml index a85c81a24..182bfb373 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.159.0" +version = "6.160.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/norecompile.jl b/src/norecompile.jl index a097fa0f0..da39fc74b 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -73,6 +73,25 @@ function wrapfun_iip(ff, FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) end +function wrapfun_iip(ff, + inputs::Tuple{T1, T2, T3}) where {T1, T2, T3} + T = eltype(T2) + dualT = dualgen(T) + dualT1 = ArrayInterface.promote_eltype(T1, dualT) + dualT2 = ArrayInterface.promote_eltype(T2, dualT) + + iip_arglists = (Tuple{T1, T2, T3}, + Tuple{dualT1, dualT2, T3}, + Tuple{dualT1, T2, T3}) + + iip_returnlists = ntuple(x -> Nothing, 3) + + fwt = map(iip_arglists, iip_returnlists) do A, R + FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff)) + end + FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) +end + const iip_arglists_default = ( Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}, diff --git a/src/solve.jl b/src/solve.jl index f6485d100..5533c0af0 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1152,14 +1152,20 @@ function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) p = get_concrete_p(prob, kwargs) u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) u0 = promote_u0(u0, p, nothing) - remake(prob; u0 = u0, p = p) + f_promote = promote_f( + prob.f, Val(SciMLBase.specialization(prob.f)), u0, p + ) + remake(prob; u0 = u0, p = p, f = f_promote) end function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) p = get_concrete_p(prob, kwargs) u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) u0 = promote_u0(u0, p, nothing) - remake(prob; u0 = u0, p = p) + f_promote = promote_f( + prob.f, Val(SciMLBase.specialization(prob.f)), u0, p + ) + remake(prob; u0 = u0, p = p, f = f_promote) end function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...) @@ -1252,28 +1258,47 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t) where {F, specialize} f = @set f.jac_prototype = similar(f.jac_prototype, uElType) end - @static if VERSION >= v"1.8-" - f = if f isa ODEFunction && isinplace(f) && !(f.f isa AbstractSciMLOperator) && - # Some reinitialization code still uses NLSolvers stuff which doesn't - # properly tag, so opt-out if potentially a mass matrix DAE - f.mass_matrix isa UniformScaling && - # Jacobians don't wrap, so just ignore those cases - f.jac === nothing && - ((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any && - RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) && - one(t) === oneunit(t) && - hasmethod(ArrayInterface.promote_eltype, - Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && - hasmethod(promote_rule, - Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && - hasmethod(promote_rule, - Tuple{Type{eltype(u0)}, Type{typeof(t)}})) || - (specialize === SciMLBase.FunctionWrapperSpecialize && - !(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper))) - return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t))) - else - return f - end + f = if f isa ODEFunction && isinplace(f) && !(f.f isa AbstractSciMLOperator) && + # Some reinitialization code still uses NLSolvers stuff which doesn't + # properly tag, so opt-out if potentially a mass matrix DAE + f.mass_matrix isa UniformScaling && + # Jacobians don't wrap, so just ignore those cases + f.jac === nothing && + ((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any && + RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) && + one(t) === oneunit(t) && + hasmethod(ArrayInterface.promote_eltype, + Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && + hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && + hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{typeof(t)}})) || + (specialize === SciMLBase.FunctionWrapperSpecialize && + !(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper))) + return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t))) + else + return f + end +end + +function promote_f(f::NonlinearFunction, ::Val{specialize}, u0, p) where {specialize} + # Ensure our jacobian will be of the same type as u0 + uElType = u0 === nothing ? Float64 : eltype(u0) + if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray + f = @set f.jac_prototype = similar(f.jac_prototype, uElType) + end + + f = if isinplace(f) && !(f.f isa AbstractSciMLOperator) && + f.jac === nothing && + ((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any && + RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) && + hasmethod(ArrayInterface.promote_eltype, + Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && + hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}})) || + (specialize === SciMLBase.FunctionWrapperSpecialize && + !(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper))) + return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p))) else return f end From 8cd8c7f0a2ac0d5ca146c16bbdc4a8ad291f162d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 29 Oct 2024 18:44:23 -0400 Subject: [PATCH 2/2] fix: conditionally remake --- src/solve.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 5533c0af0..a380675c2 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1152,20 +1152,24 @@ function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) p = get_concrete_p(prob, kwargs) u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) u0 = promote_u0(u0, p, nothing) - f_promote = promote_f( - prob.f, Val(SciMLBase.specialization(prob.f)), u0, p - ) - remake(prob; u0 = u0, p = p, f = f_promote) + f_promote = promote_f(prob.f, Val(SciMLBase.specialization(prob.f)), u0, p) + if f_promote === prob.f && u0 === prob.u0 && p === prob.p + return prob + else + return remake(prob; u0 = u0, p = p, f = f_promote) + end end function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) p = get_concrete_p(prob, kwargs) u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) u0 = promote_u0(u0, p, nothing) - f_promote = promote_f( - prob.f, Val(SciMLBase.specialization(prob.f)), u0, p - ) - remake(prob; u0 = u0, p = p, f = f_promote) + f_promote = promote_f(prob.f, Val(SciMLBase.specialization(prob.f)), u0, p) + if f_promote === prob.f && u0 === prob.u0 && p === prob.p + return prob + else + return remake(prob; u0 = u0, p = p, f = f_promote) + end end function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...)