Skip to content

Commit

Permalink
Merge pull request #241 from vpuri3/tag
Browse files Browse the repository at this point in the history
add `tag` kwarg to JacVec, HesVec, HesVecGrad
  • Loading branch information
ChrisRackauckas authored May 15, 2023
2 parents c831947 + e77e83f commit 7cb7918
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 11 deletions.
2 changes: 1 addition & 1 deletion ext/SparseDiffToolsZygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ end

function SparseDiffTools.autoback_hesvec(f, x, v)
g = x -> first(Zygote.gradient(f, x))
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
ForwardDiff.partials.(g(y), 1)
end
Expand Down
17 changes: 8 additions & 9 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,16 @@ function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
L.vecprod!(dv, L.f, L.u, v, L.cache...)
end

function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
kwargs...)
function JacVec(f, u::AbstractArray, p = nothing, t = nothing;
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)

(cache1, cache2), num_jacvec, num_jacvec!
elseif autodiff isa AutoForwardDiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))

cache2 = copy(cache1)
Expand All @@ -262,8 +262,8 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
kwargs...)
end

function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
kwargs...)
function HesVec(f, u::AbstractArray, p = nothing, t = nothing;
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)
Expand All @@ -280,7 +280,7 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"

cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
cache2 = copy(cache1)

Expand All @@ -305,16 +305,15 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
end

function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing;
autodiff = AutoForwardDiff(),
kwargs...)
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)

(cache1, cache2), num_hesvecgrad, num_hesvecgrad!
elseif autodiff isa AutoForwardDiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
cache2 = copy(cache1)

Expand Down
35 changes: 34 additions & 1 deletion test/test_jaches_products.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers
using LinearAlgebra, Test
using SparseDiffTools: get_tag, DeivVecTag

using Random
Random.seed!(123)
N = 300

struct MyTag end

N = 300
x = rand(N)
v = rand(N)

Expand Down Expand Up @@ -104,6 +107,10 @@ _dy = copy(dy);
update_coefficients!(f, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) auto_jacvec(f, v, v)

# GMRES test
out = similar(v)
@test_nowarn gmres!(out, L, v)

L = JacVec(f, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
update_coefficients!(f, x, 1.0, 1.0)
@test L * x num_jacvec(f, x, x)
Expand All @@ -121,9 +128,16 @@ _dy = copy(dy);
update_coefficients!(f, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0)num_jacvec(f, v, v) rtol=1e-6

# GMRES test
out = similar(v)
@test_nowarn gmres!(out, L, v)

# Tag test
L = JacVec(f, copy(x), 1.0, 1.0)
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{DeivVecTag, eltype(x)}
L = JacVec(f, copy(x), 1.0, 1.0; tag = MyTag())
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{MyTag, eltype(x)}

@info "HesVec"

L = HesVec(g, copy(x), 1.0, 1.0, autodiff = AutoFiniteDiff())
Expand Down Expand Up @@ -159,6 +173,7 @@ _dy = copy(dy);
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g, v, v)

# GMRES test
out = similar(v)
gmres!(out, L, v)

Expand All @@ -179,9 +194,16 @@ _dy = copy(dy);
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) autoback_hesvec(g, v, v)

# GMRES test
out = similar(v)
gmres!(out, L, v)

# Tag test
L = HesVec(g, copy(x), 1.0, 1.0; autodiff = AutoZygote())
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{DeivVecTag, eltype(x)}
L = HesVec(g, copy(x), 1.0, 1.0; autodiff = AutoZygote(), tag = MyTag())
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{MyTag, eltype(x)}

@info "HesVecGrad"

L = HesVecGrad(h, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
Expand All @@ -203,6 +225,10 @@ _dy = copy(dy);
update_coefficients!(g, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0)num_hesvec(g, v, v) rtol=1e-2

# GMRES test
out = similar(v)
gmres!(out, L, v)

L = HesVecGrad(h, copy(x), 1.0, 1.0)
update_coefficients!(g, x, 1.0, 1.0)
update_coefficients!(h, x, 1.0, 1.0)
Expand All @@ -223,6 +249,7 @@ update_coefficients!(g, v, 5.0, 6.0)
update_coefficients!(h, v, 5.0, 6.0)
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g, v, v)

# GMRES test
out = similar(v)
gmres!(out, L, v)

Expand All @@ -231,4 +258,10 @@ gmres!(out, L, v)
@test x x0
@test v v0

# Tag test
L = HesVecGrad(g, copy(x), 1.0, 1.0)
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{DeivVecTag, eltype(x)}
L = HesVecGrad(g, copy(x), 1.0, 1.0; tag = MyTag())
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{MyTag, eltype(x)}

#

0 comments on commit 7cb7918

Please sign in to comment.