From 8790d799f06669cfe31cf974b5e44d856e7d5c07 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Sep 2024 15:07:12 -0400 Subject: [PATCH] feat: add the AD workflows --- lib/SimpleNonlinearSolve/Project.toml | 6 ++++ .../src/SimpleNonlinearSolve.jl | 31 +++++++++++++++++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index 5d8817b60..f4ae4e4de 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -8,6 +8,9 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -32,6 +35,9 @@ BracketingNonlinearSolve = "1" ChainRulesCore = "1.24" CommonSolve = "0.2.4" DiffEqBase = "6.155" +DifferentiationInterface = "0.5.17" +FiniteDiff = "2.24.0" +ForwardDiff = "0.10.36" NonlinearSolveBase = "1" PrecompileTools = "1.2" Reexport = "1.2" diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 99c0be844..f1d7b0713 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -1,16 +1,25 @@ module SimpleNonlinearSolve -using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, - AutoPolyesterForwardDiff using CommonSolve: CommonSolve, solve using PrecompileTools: @compile_workload, @setup_workload using Reexport: @reexport @reexport using SciMLBase # I don't like this but needed to avoid a breaking change using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode +# AD Dependencies +using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, + AutoPolyesterForwardDiff +using DifferentiationInterface: DifferentiationInterface +# TODO: move these to extensions in a breaking change. These are not even used in the +# package, but are used to trigger the extension loading in DI.jl +using FiniteDiff: FiniteDiff +using ForwardDiff: ForwardDiff + using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder using NonlinearSolveBase: ImmutableNonlinearProblem +const DI = DifferentiationInterface + abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end is_extension_loaded(::Val) = false @@ -51,7 +60,23 @@ end function solve_adjoint_internal end @setup_workload begin - @compile_workload begin end + for T in (Float32, Float64) + prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) + prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2)) + prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2)) + + algs = [] + algs_no_iip = [] + + @compile_workload begin + for alg in algs, prob in (prob_scalar, prob_iip, prob_oop) + CommonSolve.solve(prob, alg) + end + for alg in algs_no_iip + CommonSolve.solve(prob_scalar, alg) + end + end + end end export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff