From 55a17f6f00a93c6ec80d5ba454c5cb20d37d6097 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 7 Oct 2023 10:09:11 +0200 Subject: [PATCH] Move hasbranching to FunctionProperties.jl --- Project.toml | 6 ++-- src/SciMLSensitivity.jl | 5 +-- src/hasbranching.jl | 80 ----------------------------------------- test/hasbranching.jl | 9 ----- test/runtests.jl | 3 -- 5 files changed, 3 insertions(+), 100 deletions(-) delete mode 100644 src/hasbranching.jl delete mode 100644 test/hasbranching.jl diff --git a/Project.toml b/Project.toml index 1fafc6888..bb8147156 100644 --- a/Project.toml +++ b/Project.toml @@ -7,17 +7,16 @@ version = "7.41.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -Cassette = "7057c7e9-c182-5462-911a-8362d720325c" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" -DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -48,17 +47,16 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" ADTypes = "0.1, 0.2" Adapt = "1.0, 2.0, 3.0" ArrayInterface = "7" -Cassette = "0.3.6" ChainRulesCore = "0.10.7, 1" DiffEqBase = "6.93" DiffEqCallbacks = "2.29" DiffEqNoiseProcess = "4.1.4, 5.0" -DiffRules = "1" Distributions = "0.24, 0.25" EllipsisNotation = "1" Enzyme = "0.11.6" FiniteDiff = "2" ForwardDiff = "0.10" +FunctionProperties = "0.1" FunctionWrappersWrappers = "0.1" GPUArraysCore = "0.1" LinearSolve = "2" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 3c4ffbc22..062132127 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -21,10 +21,8 @@ import TruncatedStacktraces import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache import FunctionWrappersWrappers - -using Cassette, DiffRules -using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot using EllipsisNotation +using FunctionProperties: hasbranching using Markdown @@ -41,7 +39,6 @@ import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAl AbstractSecondOrderSensitivityAlgorithm, AbstractShadowingSensitivityAlgorithm -include("hasbranching.jl") include("sensitivity_algorithms.jl") include("derivative_wrappers.jl") include("sensitivity_interface.jl") diff --git a/src/hasbranching.jl b/src/hasbranching.jl deleted file mode 100644 index c311bbf8e..000000000 --- a/src/hasbranching.jl +++ /dev/null @@ -1,80 +0,0 @@ -const printbranch = false - -Cassette.@context HasBranchingCtx - -function Cassette.overdub(ctx::HasBranchingCtx, f, args...) - if Cassette.canrecurse(ctx, f, args...) - return Cassette.recurse(ctx, f, args...) - else - return Cassette.fallback(ctx, f, args...) - end -end - -for (mod, f, n) in DiffRules.diffrules(; filter_modules = nothing) - if !(isdefined(@__MODULE__, mod) && isdefined(getfield(@__MODULE__, mod), f)) - continue # Skip rules for methods not defined in the current scope - end - @eval function Cassette.overdub(::HasBranchingCtx, f::Core.Typeof($mod.$f), - x::Vararg{Any, $n}) - f(x...) - end -end - -function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection) - ir = reflection.code_info - - if any(x -> isa(x, GotoIfNot), ir.code) - printbranch && println("GotoIfNot detected in $(reflection.method)\nir = $ir\n") - Cassette.insert_statements!(ir.code, ir.codelocs, - (stmt, i) -> i == 1 ? 3 : nothing, - (stmt, i) -> Any[Expr(:call, - Expr(:nooverdub, - GlobalRef(Base, :getfield)), - Expr(:contextslot), - QuoteNode(:metadata)), - Expr(:call, - Expr(:nooverdub, - GlobalRef(Base, :setindex!)), - SSAValue(1), true, - QuoteNode(:has_branching)), - stmt]) - Cassette.insert_statements!(ir.code, ir.codelocs, - (stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing, - (stmt, i) -> begin - callstmt = Meta.isexpr(stmt, :(=)) ? stmt.args[2] : - stmt - Meta.isexpr(stmt, :call) || - Meta.isexpr(stmt, :invoke) || return Any[stmt] - callstmt = Expr(callstmt.head, - Expr(:nooverdub, callstmt.args[1]), - callstmt.args[2:end]...) - return Any[Meta.isexpr(stmt, :(=)) ? - Expr(:(=), stmt.args[1], callstmt) : - callstmt] - end) - end - return ir -end - -const pass = Cassette.@pass _pass - -function hasbranching(f, x...) - metadata = Dict(:has_branching => false) - Cassette.overdub(Cassette.disablehooks(HasBranchingCtx(; pass, metadata)), f, x...) - return metadata[:has_branching] -end - -Cassette.overdub(::HasBranchingCtx, ::typeof(+), x...) = +(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(*), x...) = *(x...) -function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.materialize), x...) - Base.materialize(x...) -end -function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.literal_pow), x...) - Base.literal_pow(x...) -end -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.getindex), x...) = Base.getindex(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Core.Typeof), x...) = Core.Typeof(x...) -function Cassette.overdub(::HasBranchingCtx, ::Type{Base.OneTo{T}}, - stop) where {T <: Integer} - Base.OneTo{T}(stop) -end diff --git a/test/hasbranching.jl b/test/hasbranching.jl deleted file mode 100644 index ec65955ee..000000000 --- a/test/hasbranching.jl +++ /dev/null @@ -1,9 +0,0 @@ -using SciMLSensitivity, Test - -@test SciMLSensitivity.hasbranching(1, 2) do x, y - (x < 0 ? -x : x) + exp(y) -end - -@test !SciMLSensitivity.hasbranching(1, 2) do x, y - ifelse(x < 0, -x, x) + exp(y) -end diff --git a/test/runtests.jl b/test/runtests.jl index 1f3ba82af..d49d928a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,9 +53,6 @@ end end if GROUP == "All" || GROUP == "Core2" - @time @safetestset "hasbranching" begin - include("hasbranching.jl") - end @time @safetestset "Literal Adjoint" begin include("literal_adjoint.jl") end