diff --git a/ext/OptimizationReverseDiffExt.jl b/ext/OptimizationReverseDiffExt.jl index aa6bb810c..f21829a45 100644 --- a/ext/OptimizationReverseDiffExt.jl +++ b/ext/OptimizationReverseDiffExt.jl @@ -54,10 +54,13 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff, if f.hv === nothing hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) + # _θ = ForwardDiff.Dual.(θ, v) + # res = similar(_θ) + # grad(res, _θ, args...) + # H .= getindex.(ForwardDiff.partials.(res), 1) + res = zeros(length(θ), length(θ)) + hess(res, θ, args...) + H .= res * v end else hv = f.hv @@ -171,10 +174,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, if f.hv === nothing hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) + # _θ = ForwardDiff.Dual.(θ, v) + # res = similar(_θ) + # grad(res, θ, args...) + # H .= getindex.(ForwardDiff.partials.(res), 1) + res = zeros(length(θ), length(θ)) + hess(res, θ, args...) + H .= res * v end else hv = f.hv diff --git a/ext/OptimizationSparseDiffExt.jl b/ext/OptimizationSparseDiffExt.jl index ea4146662..0337e34d2 100644 --- a/ext/OptimizationSparseDiffExt.jl +++ b/ext/OptimizationSparseDiffExt.jl @@ -537,10 +537,13 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff, if f.hv === nothing hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) + # _θ = ForwardDiff.Dual.(θ, v) + # res = similar(_θ) + # grad(res, _θ, args...) + # H .= getindex.(ForwardDiff.partials.(res), 1) + res = zeros(length(θ), length(θ)) + hess(res, θ, args...) + H .= res * v end else hv = f.hv @@ -671,10 +674,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, if f.hv === nothing hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) + # _θ = ForwardDiff.Dual.(θ, v) + # res = similar(_θ) + # grad(res, _θ, args...) + # H .= getindex.(ForwardDiff.partials.(res), 1) + res = zeros(length(θ), length(θ)) + hess(res, θ, args...) + H .= res * v end else hv = f.hv