diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index dc0a7630e..4dcce49a7 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -1,6 +1,6 @@ module OptimizationEnzymeExt -import Optimization, Optimization.ArrayInterface +import Optimization, Optimization.ArrayInterface, Optimization.SparseArrays import Optimization.SciMLBase: OptimizationFunction import Optimization.LinearAlgebra: I import Optimization.ADTypes: AutoEnzyme @@ -47,11 +47,27 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x, return nothing end function hess(res, θ, args...) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0))) - - bθ = zeros(length(θ)) - vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ)) - + if f.hess_prototype === nothing + vdθ = Tuple((similar(r) for r in eachrow(I(length(θ)) * 1.0))) + bθ = zeros(length(θ)) + @show bθ + @show typeof(bθ) + vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ)) + @show vdbθ + @show typeof(vdbθ) + else + θ = SparseArrays.sparse(θ) + @show θ + vdθ = Tuple((similar(SparseArrays.sparse(r)) for r in eachrow(I(length(θ)) * 1.0))) + @show vdθ + @show typeof(vdθ) + bθ = SparseArrays.similar(θ) + @show bθ + @show typeof(bθ) + vdbθ = Tuple(similar(i) for i in eachrow(f.hess_prototype)) + @show vdbθ + @show typeof(vdbθ) + end Enzyme.autodiff(Enzyme.Forward, g, Enzyme.BatchDuplicated(θ, vdθ), diff --git a/test/ADtests.jl b/test/ADtests.jl index 7cca09fc0..f83c1b4ff 100644 --- a/test/ADtests.jl +++ b/test/ADtests.jl @@ -150,8 +150,8 @@ H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] G2 = Array{Float64}(undef, 2) H2 = Array{Float64}(undef, 2, 2) -optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote(), cons = con2_c) -optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoZygote(), +optf = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme(), cons = con2_c) +optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoEnzyme(), nothing, 2) optprob.grad(G2, x0) @test G1 == G2 @@ -167,15 +167,27 @@ H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] -optf = OptimizationFunction(rosenbrock, Optimization.AutoModelingToolkit(true, true), +optf = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme(), cons = con2_c) optprob = Optimization.instantiate_function(optf, x0, - Optimization.AutoModelingToolkit(true, true), + Optimization.AutoEnzyme(), nothing, 2) using SparseArrays sH = sparse([1, 1, 2, 2], [1, 2, 1, 2], zeros(4)) -@test findnz(sH)[1:2] == findnz(optprob.hess_prototype)[1:2] +# @test findnz(sH)[1:2] == findnz(optprob.hess_prototype)[1:2] optprob.hess(sH, x0) + +sH = sparse([1, 1, 2, 2], [1, 2, 1, 2], zeros(4)) +optf = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme(), + cons = con2_c, hess_prototype = sH) +optprob = Optimization.instantiate_function(optf, x0, + Optimization.AutoEnzyme(), + nothing, 2) +using SparseArrays + +# @test findnz(sH)[1:2] == findnz(optprob.hess_prototype)[1:2] +optprob.hess(sH, x0) + @test sH == H2 res = Array{Float64}(undef, 2) optprob.cons(res, x0)