diff --git a/ext/SparseDiffToolsZygoteExt.jl b/ext/SparseDiffToolsZygoteExt.jl index 49c3fbf0..1563c216 100644 --- a/ext/SparseDiffToolsZygoteExt.jl +++ b/ext/SparseDiffToolsZygoteExt.jl @@ -45,7 +45,8 @@ end ### Jac, Hes products -function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v)) where {F} +function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v), + cache3 = similar(v)) where {F} g = let f = f (dx, x) -> dx .= first(Zygote.gradient(f, x)) end diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 604a10ce..dead4290 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -32,8 +32,8 @@ function auto_jacvec(f, x, v) vec(partials.(vec(f(y)), 1)) end -function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v); - compute_f0 = true) +function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v), + cache3 = similar(v); compute_f0 = true) vv = reshape(v, axes(x)) compute_f0 && (f(cache1, x)) T = eltype(x) @@ -134,7 +134,8 @@ function autonum_hesvec(f, x, v) partials.(g(Dual{DeivVecTag}.(x, v)), 1) end -function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v)) +function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v), + cache3 = similar(v)) T = eltype(x) # Should it be min? max? mean? ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index 7f827583..c48b34ca 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -1,5 +1,5 @@ -function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v); - compute_f0 = true) where {F} +function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), + cache3 = similar(x); compute_f0 = true) where {F} compute_f0 && (f(cache1, x)) T = eltype(x) # Should it be min? max? mean? @@ -22,10 +22,11 @@ function num_vecjac(f::F, x, v, f0 = nothing) where {F} # Should it be min? max? mean? ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) du = similar(x) - cache = copy(x) + cache = similar(x) + copyto!(cache, x) for i in 1:length(x) cache[i] += ϵ - f0 = f(x) + f0 = f(cache) cache[i] = x[i] du[i] = (((f0 .- _f0) ./ ϵ)' * vv)[1] end