Skip to content

Commit

Permalink
Fix more dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Apr 22, 2024
1 parent 20cc8e7 commit 88aaa61
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ include("highlevel/common.jl")
include("highlevel/coloring.jl")
include("highlevel/forward_mode.jl")
include("highlevel/reverse_mode.jl")
include("highlevel/forward_or_reverse_mode.jl")
include("highlevel/finite_diff.jl")

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
Expand Down
4 changes: 2 additions & 2 deletions src/highlevel/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ If `fx` is not specified, it will be computed by calling `f(x)`.
A cache for computing the Jacobian of type `AbstractMaybeSparseJacobianCache`.
"""
function sparse_jacobian_cache(
ad::AbstractADType, sd::AbstractSparsityDetection, f, x; fx = nothing)
ad::AbstractADType, sd::AbstractMaybeSparsityDetection, f, x; fx = nothing)
return sparse_jacobian_cache_aux(mode(ad), ad, sd, f, x; fx)
end

function sparse_jacobian_cache(
ad::AbstractADType, sd::AbstractSparsityDetection, f!, x, fx)
ad::AbstractADType, sd::AbstractMaybeSparsityDetection, f!, x, fx)
return sparse_jacobian_cache_aux(mode(ad), ad, sd, f!, x, fx)
end

Expand Down
21 changes: 21 additions & 0 deletions src/highlevel/forward_or_reverse_mode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
function sparse_jacobian_cache_aux(::ForwardOrReverseMode, ad::AbstractADType,
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
if ad isa AutoEnzyme
return sparse_jacobian_cache_aux(ReverseMode(), ad, sd, f, x; fx)
elseif ad isa AutoDiffractor
return sparse_jacobian_cache_aux(ForwardMode(), ad, sd, f, x; fx)
else
error("Unknown mixed mode AD")
end
end

function sparse_jacobian_cache_aux(::ForwardOrReverseMode, ad::AbstractADType,
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
if ad isa AutoEnzyme
return sparse_jacobian_cache_aux(ReverseMode(), ad, sd, f!, fx, x)
elseif ad isa AutoDiffractor
return sparse_jacobian_cache_aux(ForwardMode(), ad, sd, f!, fx, x)
else
error("Unknown mixed mode AD")
end
end
48 changes: 24 additions & 24 deletions test/test_sparse_jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
## Sparse Jacobian tests
using SparseDiffTools,
Symbolics, ForwardDiff, LinearAlgebra, SparseArrays, Zygote, Enzyme, Test,
using ADTypes, SparseDiffTools,
Symbolics, ForwardDiff, PolyesterForwardDiff, LinearAlgebra, SparseArrays, Zygote,
Enzyme, Test,
StaticArrays
using ADTypes: dense_ad

@static if VERSION v"1.9"
using PolyesterForwardDiff
end

function __chunksize(::Union{AutoSparse{<:AutoForwardDiff}{C}, AutoForwardDiff{C},
AutoSparse{<:AutoPolyesterForwardDiff}{C}, AutoPolyesterForwardDiff{C}}) where {C}
function __chunksize(::Union{
AutoSparse{<:AutoForwardDiff{C}}, AutoForwardDiff{C},
AutoSparse{<:AutoPolyesterForwardDiff{C}}, AutoPolyesterForwardDiff{C}
}) where {C}
return C
end

function __isinferrable(difftype)
return !(difftype isa AutoSparse{<:AutoForwardDiff} || difftype isa AutoForwardDiff ||
return !(difftype isa AutoSparse{<:AutoForwardDiff} ||
difftype isa AutoForwardDiff ||
difftype isa AutoSparse{<:AutoPolyesterForwardDiff} ||
difftype isa AutoPolyesterForwardDiff) ||
(__chunksize(difftype) isa Int && __chunksize(difftype) > 0)
Expand Down Expand Up @@ -51,24 +52,23 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
PrecomputedJacobianColorvec(; jac_prototype = J_sparsity, row_colorvec, col_colorvec)]

@testset "High-Level API" begin
@testset "Sparsity Detection: $(nameof(typeof(sd)))" for sd in SPARSITY_DETECTION_ALGS
@testset "Sparsity Detection: $(nameof(typeof(sd))) - $(isa(ad, AutoSparse) ? $(nameof(typeof(dense_ad(ad)))) : "")" for sd in SPARSITY_DETECTION_ALGS
@info "Sparsity Detection: $(nameof(typeof(sd)))"
@info "Out of Place Function"

DIFFTYPES = [AutoSparse(AutoZygote()), AutoZygote(), AutoSparse(AutoForwardDiff()),
AutoForwardDiff(), AutoSparse(AutoForwardDiff(; chunksize = 0)),
AutoForwardDiff(; chunksize = 0), AutoSparse(AutoForwardDiff(; chunksize = 4)),
AutoForwardDiff(; chunksize = 4), AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(),
AutoEnzyme(), AutoSparse(AutoEnzyme())]

if VERSION v"1.9"
append!(DIFFTYPES,
[AutoSparse(AutoPolyesterForwardDiff()), AutoPolyesterForwardDiff(),
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 0)),
AutoPolyesterForwardDiff(; chunksize = 0),
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 4)),
AutoPolyesterForwardDiff(; chunksize = 4)])
end
DIFFTYPES = [
AutoSparse(AutoZygote()), AutoZygote(),
AutoSparse(AutoForwardDiff()), AutoForwardDiff(),
AutoSparse(AutoForwardDiff(; chunksize = 0)), AutoForwardDiff(; chunksize = 0),
AutoSparse(AutoForwardDiff(; chunksize = 4)), AutoForwardDiff(; chunksize = 4),
AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(),
AutoEnzyme(), AutoSparse(AutoEnzyme()),
AutoSparse(AutoPolyesterForwardDiff()), AutoPolyesterForwardDiff(),
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 0)),
AutoPolyesterForwardDiff(; chunksize = 0),
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 4)),
AutoPolyesterForwardDiff(; chunksize = 4)
]

@testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in DIFFTYPES
@testset "Cache & Reuse" begin
Expand Down

0 comments on commit 88aaa61

Please sign in to comment.