Skip to content

Commit

Permalink
Merge pull request #956 from SciML/ad_piracy
Browse files Browse the repository at this point in the history
Remove AD piracy functions by moving to SciMLBase
  • Loading branch information
ChrisRackauckas authored Nov 3, 2023
2 parents 8262b58 + 042f08e commit c9ad77c
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 352 deletions.
25 changes: 9 additions & 16 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ version = "6.136.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Expand All @@ -25,7 +23,6 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand All @@ -35,9 +32,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
Expand All @@ -47,7 +44,6 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DiffEqBaseDistributionsExt = "Distributions"
Expand All @@ -59,7 +55,6 @@ DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
DiffEqBaseReverseDiffExt = "ReverseDiff"
DiffEqBaseTrackerExt = "Tracker"
DiffEqBaseUnitfulExt = "Unitful"
DiffEqBaseZygoteExt = "Zygote"

[compat]
ArrayInterface = "7"
Expand All @@ -73,31 +68,30 @@ FastBroadcast = "0.2"
ForwardDiff = "0.10"
FunctionWrappers = "1.0"
FunctionWrappersWrappers = "0.1"
LinearAlgebra = "1.6"
Logging = "1.6"
Markdown = "1.6"
LinearAlgebra = "1.9"
Logging = "1.9"
Markdown = "1.9"
MuladdMacro = "0.2.1"
Parameters = "0.12.0"
PreallocationTools = "0.4"
PrecompileTools = "1"
Printf = "1.6"
Printf = "1.9"
RecursiveArrayTools = "2"
Reexport = "1.0"
Requires = "1.0"
SciMLBase = "2.4.0"
SciMLBase = "2.7.0"
SciMLOperators = "0.2, 0.3"
Setfield = "0.8, 1"
SparseArrays = "1.6"
SparseArrays = "1.9"
Static = "0.7, 0.8"
StaticArraysCore = "1.4"
Statistics = "1"
Tricks = "0.1.6"
TruncatedStacktraces = "1"
ZygoteRules = "0.2"
julia = "1.6"
julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -116,7 +110,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua"]
28 changes: 28 additions & 0 deletions ext/DiffEqBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module DiffEqBaseChainRulesCoreExt

using DiffEqBase
import DiffEqBase: numargs

import ChainRulesCore
import ChainRulesCore: NoTangent

ChainRulesCore.rrule(::typeof(numargs), f) = (numargs(f), df -> (NoTangent(), NoTangent()))
ChainRulesCore.@non_differentiable checkkwargs(kwargshandle)

function ChainRulesCore.frule(::typeof(solve_up), prob,
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...;
kwargs...)
_solve_forward(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
kwargs...)
end

function ChainRulesCore.rrule(::typeof(solve_up), prob::AbstractDEProblem,
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...;
kwargs...)
_solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
kwargs...)
end

end
60 changes: 0 additions & 60 deletions ext/DiffEqBaseZygoteExt.jl

This file was deleted.

16 changes: 1 addition & 15 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ if isdefined(Base, :Experimental) &&
isdefined(Base.Experimental, Symbol("@max_methods"))
@eval Base.Experimental.@max_methods 1
end
if !isdefined(Base, :get_extension)
using Requires
end

import PrecompileTools

Expand All @@ -28,14 +25,10 @@ PrecompileTools.@recompile_invalidations begin

using Static: reduce_tup

import ChainRulesCore
import RecursiveArrayTools
import SparseArrays
import TruncatedStacktraces

import ChainRulesCore: NoTangent, @non_differentiable
import ZygoteRules

using Setfield

using ForwardDiff
Expand Down Expand Up @@ -140,13 +133,10 @@ include("callbacks.jl")
include("common_defaults.jl")
include("solve.jl")
include("internal_euler.jl")
include("init.jl")
include("forwarddiff.jl")
include("chainrules.jl")

include("termination_conditions.jl")

include("norecompile.jl")

# This is only used for oop stiff solvers
default_factorize(A) = lu(A; check = false)

Expand Down Expand Up @@ -181,8 +171,4 @@ export NLSolveTerminationMode,

export KeywordArgError, KeywordArgWarn, KeywordArgSilent

if !isdefined(Base, :get_extension)
include("../ext/DiffEqBaseDistributionsExt.jl")
end

end # module
Loading

0 comments on commit c9ad77c

Please sign in to comment.