Commit
Bring back simple hvp for now
Vaibhavdixit02 committed Sep 18, 2023
1 parent 5e92e40 commit 6a54f64
Showing 2 changed files with 28 additions and 16 deletions.
22 changes: 14 additions & 8 deletions ext/OptimizationReverseDiffExt.jl
@@ -54,10 +54,13 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,

     if f.hv === nothing
         hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
+            # _θ = ForwardDiff.Dual.(θ, v)
+            # res = similar(_θ)
+            # grad(res, _θ, args...)
+            # H .= getindex.(ForwardDiff.partials.(res), 1)
+            res = zeros(length(θ), length(θ))
+            hess(res, θ, args...)
+            H .= res * v

         end
     else
         hv = f.hv
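The change in this hunk drops a Dual-seeded hvp in favor of the simple form: build the dense Hessian, then multiply by v. Below is a minimal standalone sketch of both strategies, written against ForwardDiff directly instead of this extension's internal grad/hess closures; the rosenbrock objective and the hvp_* names are illustrative, not from this commit.

    using ForwardDiff

    rosenbrock(θ) = (1 - θ[1])^2 + 100 * (θ[2] - θ[1]^2)^2

    # Simple hvp, as reinstated by this commit: form the full n-by-n
    # Hessian, then take a matrix-vector product. O(n^2) storage per call.
    function hvp_simple!(H, θ, v)
        res = ForwardDiff.hessian(rosenbrock, θ)
        H .= res * v
    end

    # Dual-seeded hvp, as in the commented-out lines: carry the direction v
    # in the partials of Dual numbers, differentiate the gradient once, and
    # read ∇²f(θ)*v back out of the partials; no full Hessian is formed.
    function hvp_dual!(H, θ, v)
        _θ = ForwardDiff.Dual.(θ, v)
        res = ForwardDiff.gradient(rosenbrock, _θ)
        H .= getindex.(ForwardDiff.partials.(res), 1)
    end

    θ, v = [0.5, 0.5], [1.0, 0.0]
    H1, H2 = similar(θ), similar(θ)
    hvp_simple!(H1, θ, v)
    hvp_dual!(H2, θ, v)
    @assert H1 ≈ H2   # both compute ∇²f(θ) * v

The simple form trades one Dual-seeded gradient pass for an O(n^2) Hessian build, which is presumably acceptable as a stopgap, hence "for now" in the commit title.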
@@ -171,10 +174,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,

     if f.hv === nothing
         hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
+            # _θ = ForwardDiff.Dual.(θ, v)
+            # res = similar(_θ)
+            # grad(res, _θ, args...)
+            # H .= getindex.(ForwardDiff.partials.(res), 1)
+            res = zeros(length(θ), length(θ))
+            hess(res, θ, args...)
+            H .= res * v

         end
     else
         hv = f.hv
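The reinstated path also allocates res = zeros(length(θ), length(θ)) on every hvp call. A hypothetical refactor, not part of this commit, would hoist the buffer out of the closure and multiply in place; make_hv and its arguments are made-up names for illustration.

    using LinearAlgebra: mul!

    # Hypothetical: allocate the dense Hessian buffer once and reuse it on
    # every call. `hess` stands in for the in-place Hessian closure that
    # instantiate_function builds earlier in the same method.
    function make_hv(hess, n)
        res = zeros(n, n)           # allocated once, captured by the closure
        return function (H, θ, v, args...)
            hess(res, θ, args...)   # overwrite res with ∇²f(θ)
            mul!(H, res, v)         # H = res * v without a temporary
        end
    end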
22 changes: 14 additions & 8 deletions ext/OptimizationSparseDiffExt.jl
@@ -537,10 +537,13 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,

     if f.hv === nothing
         hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
+            # _θ = ForwardDiff.Dual.(θ, v)
+            # res = similar(_θ)
+            # grad(res, _θ, args...)
+            # H .= getindex.(ForwardDiff.partials.(res), 1)
+            res = zeros(length(θ), length(θ))
+            hess(res, θ, args...)
+            H .= res * v

         end
     else
         hv = f.hv
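The same dense fallback lands here in the sparse extension, where zeros(length(θ), length(θ)) ignores any known sparsity. Purely as a sketch, with nothing below taken from the commit, a sparse-aware variant could keep both the buffer and the product sparse:

    using SparseArrays
    using LinearAlgebra: mul!

    # Hypothetical: reuse a known (here tridiagonal) sparsity pattern.
    n = 4
    pattern = spdiagm(-1 => ones(n - 1), 0 => ones(n), 1 => ones(n - 1))
    res = copy(pattern)       # sparse buffer with the same stored pattern
    # an in-place sparse hess(res, θ, args...) would overwrite res here
    v, H = ones(n), zeros(n)
    mul!(H, res, v)           # sparse Hessian-vector product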
@@ -671,10 +674,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,

     if f.hv === nothing
         hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
+            # _θ = ForwardDiff.Dual.(θ, v)
+            # res = similar(_θ)
+            # grad(res, _θ, args...)
+            # H .= getindex.(ForwardDiff.partials.(res), 1)
+            res = zeros(length(θ), length(θ))
+            hess(res, θ, args...)
+            H .= res * v

         end
     else
         hv = f.hv
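For context, here is a sketch of how these hv closures get exercised end to end. The instantiate_function signature matches the hunks above, but the OptimizationFunction plumbing and field access are assumptions about the Optimization.jl version at this commit.

    using Optimization, ReverseDiff   # loading ReverseDiff activates the extension

    rosen(θ, p) = (p[1] - θ[1])^2 + p[2] * (θ[2] - θ[1]^2)^2
    optf = OptimizationFunction(rosen, Optimization.AutoReverseDiff())

    θ0, p = [0.5, 0.5], [1.0, 100.0]
    inst = Optimization.instantiate_function(optf, θ0, Optimization.AutoReverseDiff(), p)

    H = similar(θ0)
    inst.hv(H, θ0, [1.0, 0.0])   # runs the dense-Hessian hvp reinstated here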
