Skip to content

Commit

Permalink
feat: add the AD workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 17, 2024
1 parent 8153008 commit 8790d79
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
6 changes: 6 additions & 0 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -32,6 +35,9 @@ BracketingNonlinearSolve = "1"
ChainRulesCore = "1.24"
CommonSolve = "0.2.4"
DiffEqBase = "6.155"
DifferentiationInterface = "0.5.17"
FiniteDiff = "2.24.0"
ForwardDiff = "0.10.36"
NonlinearSolveBase = "1"
PrecompileTools = "1.2"
Reexport = "1.2"
Expand Down
31 changes: 28 additions & 3 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
module SimpleNonlinearSolve

using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using CommonSolve: CommonSolve, solve
using PrecompileTools: @compile_workload, @setup_workload
using Reexport: @reexport
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode

# AD Dependencies
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using DifferentiationInterface: DifferentiationInterface
# TODO: move these to extensions in a breaking change. These are not even used in the
# package, but are used to trigger the extension loading in DI.jl
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff

using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
using NonlinearSolveBase: ImmutableNonlinearProblem

const DI = DifferentiationInterface

abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

is_extension_loaded(::Val) = false
Expand Down Expand Up @@ -51,7 +60,23 @@ end
function solve_adjoint_internal end

@setup_workload begin
@compile_workload begin end
for T in (Float32, Float64)
prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2))
prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2))

algs = []
algs_no_iip = []

@compile_workload begin
for alg in algs, prob in (prob_scalar, prob_iip, prob_oop)
CommonSolve.solve(prob, alg)
end
for alg in algs_no_iip
CommonSolve.solve(prob_scalar, alg)
end
end
end
end

export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
Expand Down

0 comments on commit 8790d79

Please sign in to comment.