Skip to content

Commit

Permalink
feat: add simplenonlinearsolve AD specific dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 17, 2024
1 parent c238914 commit 8153008
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 2 deletions.
2 changes: 0 additions & 2 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,4 @@ export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTermi
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
RelNormSafeNormTerminationMode, AbsNormSafeNormTerminationMode

export ImmutableNonlinearProblem

end
18 changes: 18 additions & 0 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,35 @@ version = "1.13.0"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolveDiffEqBaseExt = "DiffEqBase"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveTrackerExt = "Tracker"

[compat]
ADTypes = "1.2"
ArrayInterface = "7.16"
BracketingNonlinearSolve = "1"
ChainRulesCore = "1.24"
CommonSolve = "0.2.4"
DiffEqBase = "6.155"
NonlinearSolveBase = "1"
PrecompileTools = "1.2"
Reexport = "1.2"
ReverseDiff = "1.15"
SciMLBase = "2.50"
Tracker = "0.2.35"
julia = "1.10"
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module SimpleNonlinearSolveChainRulesCoreExt

using ChainRulesCore: ChainRulesCore, NoTangent
using NonlinearSolveBase: ImmutableNonlinearProblem
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem

using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up,
solve_adjoint

function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up),
prob::Union{InternalNonlinearProblem, NonlinearLeastSquaresProblem},
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
out, ∇internal = solve_adjoint(
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
function ∇simplenonlinearsolve_solve_up(Δ)
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Δ)
return (
∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...)
end
return out, ∇simplenonlinearsolve_solve_up
end

end
11 changes: 11 additions & 0 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module SimpleNonlinearSolveDiffEqBaseExt

using DiffEqBase: DiffEqBase

using SimpleNonlinearSolve: SimpleNonlinearSolve

function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...)
return DiffEqBase._solve_adjoint(args...; kwargs...)
end

end
37 changes: 37 additions & 0 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
module SimpleNonlinearSolveReverseDiffExt

using ArrayInterface: ArrayInterface
using NonlinearSolveBase: ImmutableNonlinearProblem
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake

using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint

for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)]
@eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
prob::$(pType), sensealg, u0::$(uT), u0_changed,
p::$(pT), p_changed, alg, args...; kwargs...)
return ReverseDiff.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up,
prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end
end

@eval ReverseDiff.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
tprob::$(pType), sensealg, tu0, u0_changed,
tp, p_changed, alg, args...; kwargs...)
u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp)
prob = remake(tprob; u0, p)
out, ∇internal = solve_adjoint(
prob, sensealg, u0, p, ReverseDiffOriginator(), alg, args...; kwargs...)

function ∇simplenonlinearsolve_solve_up...)
∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal...)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
end
end
end

end
37 changes: 37 additions & 0 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
module SimpleNonlinearSolveTrackerExt

using ArrayInterface: ArrayInterface
using NonlinearSolveBase: ImmutableNonlinearProblem
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
using Tracker: Tracker, TrackedArray, TrackedReal

using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint

for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)]
@eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
prob::$(pType), sensealg, u0::$(uT), u0_changed,
p::$(pT), p_changed, alg, args...; kwargs...)
return Tracker.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up, prob,
sensealg, ArrayInterface.aos_to_soa(u0), true,
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end
end

@eval Tracker.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
tprob::$(pType), sensealg, tu0, u0_changed,
tp, p_changed, alg, args...; kwargs...)
u0, p = Tracker.data(tu0), Tracker.data(tp)
prob = remake(tprob; u0, p)
out, ∇internal = solve_adjoint(
prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...)

function ∇simplenonlinearsolve_solve_up(Δ)
∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Tracker.data(Δ))
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
end
end
end

end
42 changes: 42 additions & 0 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,53 @@ module SimpleNonlinearSolve

using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using CommonSolve: CommonSolve, solve
using PrecompileTools: @compile_workload, @setup_workload
using Reexport: @reexport
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode

using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
using NonlinearSolveBase: ImmutableNonlinearProblem

abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

is_extension_loaded(::Val) = false

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
function CommonSolve.solve(prob::NonlinearProblem,
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
prob = convert(ImmutableNonlinearProblem, prob)
return solve(prob, alg, args...; kwargs...)
end

function CommonSolve.solve(
prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
new_u0 = u0 !== nothing ? u0 : prob.u0
new_p = p !== nothing ? p : prob.p
return simplenonlinearsolve_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function simplenonlinearsolve_solve_up(
prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed, p, p_changed,
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
(u0_changed || p_changed) && (prob = remake(prob; u0, p))
return SciMLBase.__solve(prob, alg, args...; kwargs...)
end

# NOTE: This is defined like this so that we don't have to keep have 2 args for the
# extensions
function solve_adjoint(args...; kws...)
is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...)
error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.")
end

function solve_adjoint_internal end

@setup_workload begin
@compile_workload begin end
Expand Down

0 comments on commit 8153008

Please sign in to comment.