Skip to content

Commit

Permalink
Fix num_vecjac
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 16, 2024
1 parent 1156e34 commit 4ad919a
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.15.0"
version = "2.15.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion ext/SparseDiffToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ end

### Jac, Hes products

function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v)) where {F}
function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v),
cache3 = similar(v)) where {F}
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
Expand Down
7 changes: 4 additions & 3 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ function auto_jacvec(f, x, v)
vec(partials.(vec(f(y)), 1))
end

function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v);
compute_f0 = true)
function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v),
cache3 = similar(v); compute_f0 = true)
vv = reshape(v, axes(x))
compute_f0 && (f(cache1, x))
T = eltype(x)
Expand Down Expand Up @@ -134,7 +134,8 @@ function autonum_hesvec(f, x, v)
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
end

function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v))
function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v),
cache3 = similar(v))
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
Expand Down
11 changes: 6 additions & 5 deletions src/differentiation/vecjac_products.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v);
compute_f0 = true) where {F}
function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v),
cache3 = similar(x); compute_f0 = true) where {F}
compute_f0 && (f(cache1, x))
T = eltype(x)
# Should it be min? max? mean?
Expand All @@ -22,10 +22,11 @@ function num_vecjac(f::F, x, v, f0 = nothing) where {F}
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
du = similar(x)
cache = copy(x)
cache = similar(x)
copyto!(cache, x)
for i in 1:length(x)
cache[i] += ϵ
f0 = f(x)
f0 = f(cache)
cache[i] = x[i]
du[i] = (((f0 .- _f0) ./ ϵ)' * vv)[1]
end
Expand Down Expand Up @@ -93,7 +94,7 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
end

function _vecjac(f::F, fu, u, autodiff::AutoFiniteDiff) where {F}
cache = (similar(fu), similar(fu), similar(fu))
cache = (similar(fu), similar(fu), similar(u))
pullback = nothing
return AutoDiffVJP(f, u, cache, autodiff, pullback)
end
Expand Down

0 comments on commit 4ad919a

Please sign in to comment.