From 6a54f64b83bea893cc4ece88e7eab9000aa5d4c9 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 18 Sep 2023 12:54:26 -0400 Subject: [PATCH] Bring back simple hvp for now --- ext/OptimizationReverseDiffExt.jl | 22 ++++++++++++++-------- ext/OptimizationSparseDiffExt.jl | 22 ++++++++++++++-------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/ext/OptimizationReverseDiffExt.jl b/ext/OptimizationReverseDiffExt.jl index aa6bb810c..f21829a45 100644 --- a/ext/OptimizationReverseDiffExt.jl +++ b/ext/OptimizationReverseDiffExt.jl @@ -54,10 +54,13 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff, if f.hv === nothing hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) + # _θ = ForwardDiff.Dual.(θ, v) + # res = similar(_θ) + # grad(res, _θ, args...) + # H .= getindex.(ForwardDiff.partials.(res), 1) + res = zeros(length(θ), length(θ)) + hess(res, θ, args...) + H .= res * v end else hv = f.hv @@ -171,10 +174,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, if f.hv === nothing hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) + # _θ = ForwardDiff.Dual.(θ, v) + # res = similar(_θ) + # grad(res, θ, args...) + # H .= getindex.(ForwardDiff.partials.(res), 1) + res = zeros(length(θ), length(θ)) + hess(res, θ, args...) + H .= res * v end else hv = f.hv diff --git a/ext/OptimizationSparseDiffExt.jl b/ext/OptimizationSparseDiffExt.jl index ea4146662..0337e34d2 100644 --- a/ext/OptimizationSparseDiffExt.jl +++ b/ext/OptimizationSparseDiffExt.jl @@ -537,10 +537,13 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff, if f.hv === nothing hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) + # _θ = ForwardDiff.Dual.(θ, v) + # res = similar(_θ) + # grad(res, _θ, args...) + # H .= getindex.(ForwardDiff.partials.(res), 1) + res = zeros(length(θ), length(θ)) + hess(res, θ, args...) + H .= res * v end else hv = f.hv @@ -671,10 +674,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, if f.hv === nothing hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) + # _θ = ForwardDiff.Dual.(θ, v) + # res = similar(_θ) + # grad(res, _θ, args...) + # H .= getindex.(ForwardDiff.partials.(res), 1) + res = zeros(length(θ), length(θ)) + hess(res, θ, args...) + H .= res * v end else hv = f.hv