Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Apr 8, 2023
1 parent f1e4f85 commit c831947
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 108 deletions.
25 changes: 16 additions & 9 deletions ext/SparseDiffToolsZygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ end

### Jac, Hes products

function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v),
cache2 = similar(v))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
Expand Down Expand Up @@ -42,14 +43,20 @@ function SparseDiffTools.numback_hesvec(f, x, v)
end

function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(tuple.(reshape(v, size(x))))))
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),
eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(tuple.(reshape(v,
size(x))))),
cache2 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),
eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(tuple.(reshape(v,
size(x))))))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
Expand Down
2 changes: 1 addition & 1 deletion src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function auto_vecjac! end

@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("../ext/SparseDiffToolsZygote.jl")
@reexport using .SparseDiffToolsZygote
end
Expand Down
4 changes: 2 additions & 2 deletions src/coloring/high_level.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ Note that if A isa SparseMatrixCSC, the sparsity pattern is defined by structura
ie includes explicitly stored zeros.
"""
function ArrayInterface.matrix_colors(A::AbstractMatrix,
alg::SparseDiffToolsColoringAlgorithm = GreedyD1Color();
partition_by_rows::Bool = false)
alg::SparseDiffToolsColoringAlgorithm = GreedyD1Color();
partition_by_rows::Bool = false)
_A = A isa SparseMatrixCSC ? A : sparse(A) # Avoid the copy
A_graph = matrix2graph(_A, partition_by_rows)
return color_graph(A_graph, alg)
Expand Down
2 changes: 1 addition & 1 deletion src/differentiation/compute_jacobian_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function ForwardColorJacCache(f::F, x, _chunksize = nothing;
_dx = similar(x)
else
tup = ArrayInterface.allowed_getindex(ArrayInterface.allowed_getindex(p, 1),
1) .* false
1) .* false
_pi = adapt(parameterless_type(dx), [tup for i in 1:length(dx)])
fx = reshape(Dual{T, eltype(dx), length(tup)}.(vec(dx), ForwardDiff.Partials.(_pi)),
size(dx)...)
Expand Down
33 changes: 14 additions & 19 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ end

### Operator Forms

struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd
struct FwdModeAutoDiffVecProd{F, U, C, V, V!} <: AbstractAutoDiffVecProd
f::F
u::U
cache::C
Expand Down Expand Up @@ -230,16 +230,15 @@ end

function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
kwargs...)

cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)

(cache1, cache2), num_jacvec, num_jacvec!
elseif autodiff isa AutoForwardDiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))

cache2 = copy(cache1)

Expand All @@ -249,7 +248,7 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
end

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u,)))
isinplace = static_hasmethod(f, typeof((u, u)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u), or f(du, u).")
Expand All @@ -260,13 +259,11 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
kwargs...,
)
kwargs...)
end

function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
kwargs...)

cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)
Expand All @@ -283,8 +280,8 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"

cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
cache2 = copy(cache1)

(cache1, cache2), autoback_hesvec, autoback_hesvec!
Expand All @@ -293,7 +290,7 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
end

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u,)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u).")
Expand All @@ -304,13 +301,12 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
kwargs...,
)
kwargs...)
end

function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing;
autodiff = AutoForwardDiff(),
kwargs...)

cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)
Expand All @@ -319,7 +315,7 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = Au
elseif autodiff isa AutoForwardDiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
}.(u, ForwardDiff.Partials.(tuple.(u)))
cache2 = copy(cache1)

(cache1, cache2), auto_hesvecgrad, auto_hesvecgrad!
Expand All @@ -328,7 +324,7 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = Au
end

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u,)))
isinplace = static_hasmethod(f, typeof((u, u)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u), or f(du, u).")
Expand All @@ -339,7 +335,6 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = Au
FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
kwargs...,
)
kwargs...)
end
#
20 changes: 8 additions & 12 deletions src/differentiation/vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end

### Operator Forms

struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffVecProd
f::F
u::U
cache::C
Expand All @@ -57,10 +57,8 @@ struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
typeof(u),
typeof(cache),
typeof(vecprod),
typeof(vecprod!),
}(
f, u, cache, vecprod, vecprod!,
)
typeof(vecprod!)
}(f, u, cache, vecprod, vecprod!)
end
end

Expand All @@ -81,17 +79,16 @@ function (L::RevModeAutoDiffVecProd)(v, p, t)
end

# prefer non in-place method
function (L::RevModeAutoDiffVecProd{ad,iip,true})(dv, v, p, t) where{ad,iip}
function (L::RevModeAutoDiffVecProd{ad, iip, true})(dv, v, p, t) where {ad, iip}
L.vecprod!(dv, L.f, L.u, v, L.cache...)
end

function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad}
function (L::RevModeAutoDiffVecProd{ad, true, false})(dv, v, p, t) where {ad}
L.vecprod!(dv, L.f, L.u, v, L.cache...)
end

function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
kwargs...)

vecprod, vecprod! = if autodiff isa AutoFiniteDiff
num_vecjac, num_vecjac!
elseif autodiff isa AutoZygote
Expand All @@ -100,10 +97,10 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFi
auto_vecjac, auto_vecjac!
end

cache = (similar(u), similar(u),)
cache = (similar(u), similar(u))

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u,)))
isinplace = static_hasmethod(f, typeof((u, u)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u), or f(du, u)")
Expand All @@ -115,7 +112,6 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFi
FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
kwargs...
)
kwargs...)
end
#
Loading

0 comments on commit c831947

Please sign in to comment.