Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Apr 22, 2024
1 parent dbd8a73 commit db9d14a
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 21 deletions.
2 changes: 1 addition & 1 deletion ext/SparseDiffToolsEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module SparseDiffToolsEnzymeExt

import ArrayInterface: fast_scalar_indexing
import SparseDiffTools: __f̂, __maybe_copy_x, __jacobian!, __gradient, __gradient!,
AutoSparse{<:AutoEnzyme}, __test_backend_loaded
__test_backend_loaded
# FIXME: For Enzyme we currently assume reverse mode
import ADTypes: AutoSparse, AutoEnzyme
using Enzyme
Expand Down
4 changes: 2 additions & 2 deletions ext/SparseDiffToolsPolyesterForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
x::X
end

function sparse_jacobian_cache(
function sparse_jacobian_cache_aux(::ADTypes.ForwardMode,
ad::Union{AutoSparse{<:AutoPolyesterForwardDiff}, AutoPolyesterForwardDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(ad, f, x)
Expand All @@ -35,7 +35,7 @@ function sparse_jacobian_cache(
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
end

function sparse_jacobian_cache(
function sparse_jacobian_cache_aux(::ADTypes.ForwardMode,
ad::Union{AutoSparse{<:AutoPolyesterForwardDiff}, AutoPolyesterForwardDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(ad, f!, fx, x)
Expand Down
11 changes: 7 additions & 4 deletions ext/SparseDiffToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import SparseDiffTools: numback_hesvec!,
numback_hesvec, autoback_hesvec!, autoback_hesvec, auto_vecjac!,
auto_vecjac
import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!
import ADTypes: AutoZygote, AutoSparse{<:AutoZygote}
import ADTypes: AutoZygote, AutoSparse

@inline __test_backend_loaded(::Union{AutoSparse{<:AutoZygote}, AutoZygote}) = nothing

Expand All @@ -21,15 +21,17 @@ function __gradient(::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f::F, x, cols
return vec(∂x)
end

function __gradient!(::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f!::F, fx, x, cols) where {F}
function __gradient!(
::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f!::F, fx, x, cols) where {F}
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
end

# Zygote doesn't provide a way to accumulate directly into `J`. So we modify the code from
# https://github.com/FluxML/Zygote.jl/blob/82c7a000bae7fb0999275e62cc53ddb61aed94c7/src/lib/grad.jl#L140-L157C4
import Zygote: _jvec, _eyelike, _gradcopy!

@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f::F,
@views function __jacobian!(
J::AbstractMatrix, ::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f::F,
x) where {F}
y, back = Zygote.pullback(_jvec f, x)
δ = _eyelike(y)
Expand All @@ -40,7 +42,8 @@ import Zygote: _jvec, _eyelike, _gradcopy!
return J
end

function __jacobian!(_, ::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f!::F, fx, x) where {F}
function __jacobian!(
_, ::Union{AutoSparse{<:AutoZygote}, AutoZygote}, f!::F, fx, x) where {F}
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
end

Expand Down
2 changes: 1 addition & 1 deletion src/highlevel/coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ end

# Prespecified Colorvecs
function (alg::PrecomputedJacobianColorvec)(ad::AutoSparse, args...; kwargs...)
colorvec = _get_colorvec(alg, ad)
colorvec = _get_colorvec(alg, mode(ad))
J = alg.jac_prototype
(nz_rows, nz_cols) = ArrayInterface.findstructralnz(J)
return MatrixColoringResult(colorvec, J, nz_rows, nz_cols)
Expand Down
4 changes: 3 additions & 1 deletion src/highlevel/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ If `fx` is not specified, it will be computed by calling `f(x)`.
A cache for computing the Jacobian of type `AbstractMaybeSparseJacobianCache`.
"""
function sparse_jacobian_cache end
function sparse_jacobian_cache(ad::AbstractADType, sd::AbstractSparsityDetection, args...)
return sparse_jacobian_cache_aux(mode(ad), ad, sd, args...)
end

function sparse_jacobian_static_array(ad, cache, f, x::SArray)
# Not the most performant fallback
Expand Down
6 changes: 4 additions & 2 deletions src/highlevel/finite_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ end

__getfield(c::FiniteDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype

function sparse_jacobian_cache(fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFiniteDiff},
function sparse_jacobian_cache_aux(
::ForwardMode, fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFiniteDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(fd, f, x)
fx = fx === nothing ? similar(f(x)) : fx
Expand All @@ -23,7 +24,8 @@ function sparse_jacobian_cache(fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFinit
return FiniteDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
end

function sparse_jacobian_cache(fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFiniteDiff},
function sparse_jacobian_cache_aux(
::ForwardMode, fd::Union{AutoSparse{<:AutoFiniteDiff}, AutoFiniteDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(fd, f!, fx, x)
if coloring_result isa NoMatrixColoring
Expand Down
6 changes: 4 additions & 2 deletions src/highlevel/forward_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ __standard_tag(::Nothing, f::F, x) where {F} = ForwardDiff.Tag(f, eltype(x))
__standard_tag(tag::ForwardDiff.Tag, ::F, _) where {F} = tag
__standard_tag(tag, f::F, x) where {F} = ForwardDiff.Tag(f, eltype(x))

function sparse_jacobian_cache(ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
function sparse_jacobian_cache_aux(
::ForwardMode, ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(ad, f, x)
fx = fx === nothing ? similar(f(x)) : fx
Expand All @@ -29,7 +30,8 @@ function sparse_jacobian_cache(ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForw
return ForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
end

function sparse_jacobian_cache(ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
function sparse_jacobian_cache_aux(
::ForwardMode, ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(ad, f!, fx, x)
tag = __standard_tag(ad.tag, f!, x)
Expand Down
4 changes: 2 additions & 2 deletions src/highlevel/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end

__getfield(c::ReverseModeJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype

function sparse_jacobian_cache(ad::Union{AutoEnzyme, ReverseMode},
function sparse_jacobian_cache_aux(::ReverseMode, ad::AbstractADType,
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
fx = fx === nothing ? similar(f(x)) : fx
coloring_result = sd(ad, f, x)
Expand All @@ -18,7 +18,7 @@ function sparse_jacobian_cache(ad::Union{AutoEnzyme, ReverseMode},
collect(1:length(fx)))
end

function sparse_jacobian_cache(ad::Union{AutoEnzyme, ReverseMode},
function sparse_jacobian_cache_aux(::ReverseMode, ad::AbstractADType,
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(ad, f!, fx, x)
jac_prototype = __getfield(coloring_result, Val(:jacobian_sparsity))
Expand Down
14 changes: 8 additions & 6 deletions test/test_sparse_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
@info "Out of Place Function"

DIFFTYPES = [AutoSparse(AutoZygote()), AutoZygote(), AutoSparse(AutoForwardDiff()),
AutoForwardDiff(), AutoSparse{<:AutoForwardDiff}(; chunksize = 0),
AutoForwardDiff(; chunksize = 0), AutoSparse{<:AutoForwardDiff}(; chunksize = 4),
AutoForwardDiff(), AutoSparse(AutoForwardDiff(; chunksize = 0)),
AutoForwardDiff(; chunksize = 0), AutoSparse(AutoForwardDiff(; chunksize = 4)),
AutoForwardDiff(; chunksize = 4), AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(),
AutoEnzyme(), AutoSparse(AutoEnzyme())]

if VERSION v"1.9"
append!(DIFFTYPES,
[AutoSparse(AutoPolyesterForwardDiff()), AutoPolyesterForwardDiff(),
AutoSparse{<:AutoPolyesterForwardDiff}(; chunksize = 0),
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 0)),
AutoPolyesterForwardDiff(; chunksize = 0),
AutoSparse{<:AutoPolyesterForwardDiff}(; chunksize = 4),
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 4)),
AutoPolyesterForwardDiff(; chunksize = 4)])
end

Expand Down Expand Up @@ -124,7 +124,8 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
@testset "sparse_jacobian $(nameof(typeof(difftype))): In place" for difftype in (
AutoSparse(AutoForwardDiff()),
AutoForwardDiff(), AutoSparse{<:AutoForwardDiff}(; chunksize = 0),
AutoForwardDiff(; chunksize = 0), AutoSparse{<:AutoForwardDiff}(; chunksize = 4),
AutoForwardDiff(; chunksize = 0), AutoSparse{<:AutoForwardDiff}(;
chunksize = 4),
AutoForwardDiff(; chunksize = 4), AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(),
AutoEnzyme(), AutoSparse(AutoEnzyme()))
y = similar(x)
Expand Down Expand Up @@ -211,7 +212,8 @@ end
end

@testset "Static Arrays" begin
@testset "No Allocations: $(difftype)" for difftype in (AutoSparse(AutoForwardDiff()),
@testset "No Allocations: $(difftype)" for difftype in (
AutoSparse(AutoForwardDiff()),
AutoForwardDiff())
J = __sparse_jacobian_no_allocs(difftype, NoSparsityDetection(), fvcat, x_sa)
@test J J_true_sa
Expand Down

0 comments on commit db9d14a

Please sign in to comment.