diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index f0459447b..25fe50023 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -27,6 +27,4 @@ export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTermi AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode, RelNormSafeNormTerminationMode, AbsNormSafeNormTerminationMode -export ImmutableNonlinearProblem - end diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index e2c92c016..5d8817b60 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -7,17 +7,35 @@ version = "1.13.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" +CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[extensions] +SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore" +SimpleNonlinearSolveDiffEqBaseExt = "DiffEqBase" +SimpleNonlinearSolveReverseDiffExt = "ReverseDiff" +SimpleNonlinearSolveTrackerExt = "Tracker" + [compat] ADTypes = "1.2" ArrayInterface = "7.16" BracketingNonlinearSolve = "1" +ChainRulesCore = "1.24" +CommonSolve = "0.2.4" +DiffEqBase = "6.155" NonlinearSolveBase = "1" PrecompileTools = "1.2" Reexport = "1.2" +ReverseDiff = "1.15" SciMLBase = "2.50" +Tracker = "0.2.35" julia = "1.10" diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl new file mode 100644 index 000000000..df0bd7573 --- /dev/null +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl @@ -0,0 +1,23 @@ +module SimpleNonlinearSolveChainRulesCoreExt + +using ChainRulesCore: ChainRulesCore, NoTangent +using NonlinearSolveBase: ImmutableNonlinearProblem +using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem + +using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up, + solve_adjoint + +function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up), + prob::Union{InternalNonlinearProblem, NonlinearLeastSquaresProblem}, + sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) + out, ∇internal = solve_adjoint( + prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...) + function ∇simplenonlinearsolve_solve_up(Δ) + ∂f, ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Δ) + return ( + ∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...) + end + return out, ∇simplenonlinearsolve_solve_up +end + +end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl new file mode 100644 index 000000000..950a04019 --- /dev/null +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl @@ -0,0 +1,11 @@ +module SimpleNonlinearSolveDiffEqBaseExt + +using DiffEqBase: DiffEqBase + +using SimpleNonlinearSolve: SimpleNonlinearSolve + +function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...) + return DiffEqBase._solve_adjoint(args...; kwargs...) +end + +end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl new file mode 100644 index 000000000..1357bec83 --- /dev/null +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -0,0 +1,37 @@ +module SimpleNonlinearSolveReverseDiffExt + +using ArrayInterface: ArrayInterface +using NonlinearSolveBase: ImmutableNonlinearProblem +using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal +using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake + +using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint + +for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem) + aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any) + for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)] + @eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up( + prob::$(pType), sensealg, u0::$(uT), u0_changed, + p::$(pT), p_changed, alg, args...; kwargs...) + return ReverseDiff.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up, + prob, sensealg, ArrayInterface.aos_to_soa(u0), true, + ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) + end + end + + @eval ReverseDiff.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up( + tprob::$(pType), sensealg, tu0, u0_changed, + tp, p_changed, alg, args...; kwargs...) + u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp) + prob = remake(tprob; u0, p) + out, ∇internal = solve_adjoint( + prob, sensealg, u0, p, ReverseDiffOriginator(), alg, args...; kwargs...) + + function ∇simplenonlinearsolve_solve_up(Δ...) + ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Δ...) + return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) + end + end +end + +end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl new file mode 100644 index 000000000..935484db1 --- /dev/null +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -0,0 +1,37 @@ +module SimpleNonlinearSolveTrackerExt + +using ArrayInterface: ArrayInterface +using NonlinearSolveBase: ImmutableNonlinearProblem +using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake +using Tracker: Tracker, TrackedArray, TrackedReal + +using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint + +for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem) + aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any) + for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)] + @eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up( + prob::$(pType), sensealg, u0::$(uT), u0_changed, + p::$(pT), p_changed, alg, args...; kwargs...) + return Tracker.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up, prob, + sensealg, ArrayInterface.aos_to_soa(u0), true, + ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) + end + end + + @eval Tracker.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up( + tprob::$(pType), sensealg, tu0, u0_changed, + tp, p_changed, alg, args...; kwargs...) + u0, p = Tracker.data(tu0), Tracker.data(tp) + prob = remake(tprob; u0, p) + out, ∇internal = solve_adjoint( + prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...) + + function ∇simplenonlinearsolve_solve_up(Δ) + ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Tracker.data(Δ)) + return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) + end + end +end + +end diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 0debfd328..99c0be844 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -2,11 +2,53 @@ module SimpleNonlinearSolve using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff +using CommonSolve: CommonSolve, solve using PrecompileTools: @compile_workload, @setup_workload using Reexport: @reexport @reexport using SciMLBase # I don't like this but needed to avoid a breaking change +using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder +using NonlinearSolveBase: ImmutableNonlinearProblem + +abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end + +is_extension_loaded(::Val) = false + +# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms +function CommonSolve.solve(prob::NonlinearProblem, + alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) + prob = convert(ImmutableNonlinearProblem, prob) + return solve(prob, alg, args...; kwargs...) +end + +function CommonSolve.solve( + prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) + if sensealg === nothing && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + new_u0 = u0 !== nothing ? u0 : prob.u0 + new_p = p !== nothing ? p : prob.p + return simplenonlinearsolve_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, + p === nothing, alg, args...; prob.kwargs..., kwargs...) +end + +function simplenonlinearsolve_solve_up( + prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed, p, p_changed, + alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) + (u0_changed || p_changed) && (prob = remake(prob; u0, p)) + return SciMLBase.__solve(prob, alg, args...; kwargs...) +end + +# NOTE: This is defined like this so that we don't have to keep have 2 args for the +# extensions +function solve_adjoint(args...; kws...) + is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...) + error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.") +end + +function solve_adjoint_internal end @setup_workload begin @compile_workload begin end