From 88aaa61a0b949c6dea0eabc0941b25c32c71053e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 22 Apr 2024 14:00:24 +0200 Subject: [PATCH] Fix more dispatches --- src/SparseDiffTools.jl | 1 + src/highlevel/common.jl | 4 +- src/highlevel/forward_or_reverse_mode.jl | 21 +++++++++++ test/test_sparse_jacobian.jl | 48 ++++++++++++------------ 4 files changed, 48 insertions(+), 26 deletions(-) create mode 100644 src/highlevel/forward_or_reverse_mode.jl diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 024c807d..c7d3d979 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -54,6 +54,7 @@ include("highlevel/common.jl") include("highlevel/coloring.jl") include("highlevel/forward_mode.jl") include("highlevel/reverse_mode.jl") +include("highlevel/forward_or_reverse_mode.jl") include("highlevel/finite_diff.jl") Base.@pure __parameterless_type(T) = Base.typename(T).wrapper diff --git a/src/highlevel/common.jl b/src/highlevel/common.jl index 4173422c..dbc8bfcd 100644 --- a/src/highlevel/common.jl +++ b/src/highlevel/common.jl @@ -171,12 +171,12 @@ If `fx` is not specified, it will be computed by calling `f(x)`. A cache for computing the Jacobian of type `AbstractMaybeSparseJacobianCache`. """ function sparse_jacobian_cache( - ad::AbstractADType, sd::AbstractSparsityDetection, f, x; fx = nothing) + ad::AbstractADType, sd::AbstractMaybeSparsityDetection, f, x; fx = nothing) return sparse_jacobian_cache_aux(mode(ad), ad, sd, f, x; fx) end function sparse_jacobian_cache( - ad::AbstractADType, sd::AbstractSparsityDetection, f!, x, fx) + ad::AbstractADType, sd::AbstractMaybeSparsityDetection, f!, x, fx) return sparse_jacobian_cache_aux(mode(ad), ad, sd, f!, x, fx) end diff --git a/src/highlevel/forward_or_reverse_mode.jl b/src/highlevel/forward_or_reverse_mode.jl new file mode 100644 index 00000000..2f06d84c --- /dev/null +++ b/src/highlevel/forward_or_reverse_mode.jl @@ -0,0 +1,21 @@ +function sparse_jacobian_cache_aux(::ForwardOrReverseMode, ad::AbstractADType, + sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F} + if ad isa AutoEnzyme + return sparse_jacobian_cache_aux(ReverseMode(), ad, sd, f, x; fx) + elseif ad isa AutoDiffractor + return sparse_jacobian_cache_aux(ForwardMode(), ad, sd, f, x; fx) + else + error("Unknown mixed mode AD") + end +end + +function sparse_jacobian_cache_aux(::ForwardOrReverseMode, ad::AbstractADType, + sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F} + if ad isa AutoEnzyme + return sparse_jacobian_cache_aux(ReverseMode(), ad, sd, f!, fx, x) + elseif ad isa AutoDiffractor + return sparse_jacobian_cache_aux(ForwardMode(), ad, sd, f!, fx, x) + else + error("Unknown mixed mode AD") + end +end diff --git a/test/test_sparse_jacobian.jl b/test/test_sparse_jacobian.jl index 07992b9e..95f3aa7e 100644 --- a/test/test_sparse_jacobian.jl +++ b/test/test_sparse_jacobian.jl @@ -1,19 +1,20 @@ ## Sparse Jacobian tests -using SparseDiffTools, - Symbolics, ForwardDiff, LinearAlgebra, SparseArrays, Zygote, Enzyme, Test, +using ADTypes, SparseDiffTools, + Symbolics, ForwardDiff, PolyesterForwardDiff, LinearAlgebra, SparseArrays, Zygote, + Enzyme, Test, StaticArrays +using ADTypes: dense_ad -@static if VERSION ≥ v"1.9" - using PolyesterForwardDiff -end - -function __chunksize(::Union{AutoSparse{<:AutoForwardDiff}{C}, AutoForwardDiff{C}, - AutoSparse{<:AutoPolyesterForwardDiff}{C}, AutoPolyesterForwardDiff{C}}) where {C} +function __chunksize(::Union{ + AutoSparse{<:AutoForwardDiff{C}}, AutoForwardDiff{C}, + AutoSparse{<:AutoPolyesterForwardDiff{C}}, AutoPolyesterForwardDiff{C} +}) where {C} return C end function __isinferrable(difftype) - return !(difftype isa AutoSparse{<:AutoForwardDiff} || difftype isa AutoForwardDiff || + return !(difftype isa AutoSparse{<:AutoForwardDiff} || + difftype isa AutoForwardDiff || difftype isa AutoSparse{<:AutoPolyesterForwardDiff} || difftype isa AutoPolyesterForwardDiff) || (__chunksize(difftype) isa Int && __chunksize(difftype) > 0) @@ -51,24 +52,23 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa PrecomputedJacobianColorvec(; jac_prototype = J_sparsity, row_colorvec, col_colorvec)] @testset "High-Level API" begin - @testset "Sparsity Detection: $(nameof(typeof(sd)))" for sd in SPARSITY_DETECTION_ALGS + @testset "Sparsity Detection: $(nameof(typeof(sd))) - $(isa(ad, AutoSparse) ? $(nameof(typeof(dense_ad(ad)))) : "")" for sd in SPARSITY_DETECTION_ALGS @info "Sparsity Detection: $(nameof(typeof(sd)))" @info "Out of Place Function" - DIFFTYPES = [AutoSparse(AutoZygote()), AutoZygote(), AutoSparse(AutoForwardDiff()), - AutoForwardDiff(), AutoSparse(AutoForwardDiff(; chunksize = 0)), - AutoForwardDiff(; chunksize = 0), AutoSparse(AutoForwardDiff(; chunksize = 4)), - AutoForwardDiff(; chunksize = 4), AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(), - AutoEnzyme(), AutoSparse(AutoEnzyme())] - - if VERSION ≥ v"1.9" - append!(DIFFTYPES, - [AutoSparse(AutoPolyesterForwardDiff()), AutoPolyesterForwardDiff(), - AutoSparse(AutoPolyesterForwardDiff(; chunksize = 0)), - AutoPolyesterForwardDiff(; chunksize = 0), - AutoSparse(AutoPolyesterForwardDiff(; chunksize = 4)), - AutoPolyesterForwardDiff(; chunksize = 4)]) - end + DIFFTYPES = [ + AutoSparse(AutoZygote()), AutoZygote(), + AutoSparse(AutoForwardDiff()), AutoForwardDiff(), + AutoSparse(AutoForwardDiff(; chunksize = 0)), AutoForwardDiff(; chunksize = 0), + AutoSparse(AutoForwardDiff(; chunksize = 4)), AutoForwardDiff(; chunksize = 4), + AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(), + AutoEnzyme(), AutoSparse(AutoEnzyme()), + AutoSparse(AutoPolyesterForwardDiff()), AutoPolyesterForwardDiff(), + AutoSparse(AutoPolyesterForwardDiff(; chunksize = 0)), + AutoPolyesterForwardDiff(; chunksize = 0), + AutoSparse(AutoPolyesterForwardDiff(; chunksize = 4)), + AutoPolyesterForwardDiff(; chunksize = 4) + ] @testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in DIFFTYPES @testset "Cache & Reuse" begin