Add sparse ReverseDiff (AutoSparseReverseDiff) with tape compilation
Vaibhavdixit02 committed Sep 17, 2023
1 parent 4e54703 commit 5e92e40
Showing 2 changed files with 124 additions and 16 deletions.
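For orientation, here is a minimal usage sketch of what this commit enables: passing Optimization.AutoSparseReverseDiff(true) selects the compiled ReverseDiff gradient tape and the sparse forward-over-reverse Hessian added below, while AutoSparseReverseDiff() keeps the uncompiled path. The objective, starting point, and the OptimizationOptimJL/Newton solver choice are illustrative assumptions, not part of the commit.

using Optimization, OptimizationOptimJL

# Placeholder problem, mirroring the rosenbrock test added in test/ADtests.jl below.
rosenbrock(x, p) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
x0 = zeros(2)

# true => compile the ReverseDiff gradient tape; omit the argument to skip compilation.
optf = OptimizationFunction(rosenbrock, Optimization.AutoSparseReverseDiff(true))
prob = OptimizationProblem(optf, x0)

# Newton() exercises the sparse Hessian path; any Hessian-using Optim solver would do.
sol = solve(prob, Newton())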
117 changes: 101 additions & 16 deletions ext/OptimizationSparseDiffExt.jl
@@ -485,14 +485,24 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
lag_h, f.lag_hess_prototype)
end

struct OptimizationSparseReverseTag end

function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
p = SciMLBase.NullParameters(),
num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))

if f.grad === nothing
cfg = ReverseDiff.GradientConfig(x)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
if adtype.compile
_tape = ReverseDiff.GradientTape(_f, x)
tape = ReverseDiff.compile(_tape)
grad = function (res, θ, args...)
ReverseDiff.gradient!(res, tape, θ)
end
else
cfg = ReverseDiff.GradientConfig(x)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
end
@@ -502,9 +512,23 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
if f.hess === nothing
hess_sparsity = Symbolics.hessian_sparsity(_f, x)
hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
hess = function (res, θ, args...)
res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
ReverseDiff.gradient(x -> _f(x, args...), θ)
if adtype.compile
T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(x))
xdual = ForwardDiff.Dual{typeof(T),eltype(x),length(x)}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
function g(res1, θ)
ReverseDiff.gradient!(res1, htape, θ)
end
jaccfg = ForwardColorJacCache(g, x; tag = typeof(T), colorvec = hess_colors, sparsity = hess_sparsity)
hess = function (res, θ, args...)
SparseDiffTools.forwarddiff_color_jacobian!(res, g, θ, jaccfg)
end
else
hess = function (res, θ, args...)
res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
ReverseDiff.gradient(x -> _f(x, args...), θ)
end
end
end
else
@@ -553,10 +577,28 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(x))
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
cons_h = function (res, θ)
for i in 1:num_cons
res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
ReverseDiff.gradient(fncs[i], θ)
if adtype.compile
T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(x))
xduals = [ForwardDiff.Dual{typeof(T),eltype(x),maximum(conshess_colors[i])}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), maximum(conshess_colors[i]))...,)))) for i in 1:num_cons]
consh_tapes = [ReverseDiff.GradientTape(fncs[i], xduals[i]) for i in 1:num_cons]
conshtapes = ReverseDiff.compile.(consh_tapes)
function grad_cons(res1, θ, htape)
ReverseDiff.gradient!(res1, htape, θ)
end
gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardColorJacCache(gs[i], x; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i])
end
end
else
cons_h = function (res, θ)
for i in 1:num_cons
res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
ReverseDiff.gradient(fncs[i], θ)
end
end
end
end
@@ -585,7 +627,16 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
_f = (θ, args...) -> first(f.f(θ, cache.p, args...))

if f.grad === nothing
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
if adtype.compile
_tape = ReverseDiff.GradientTape(_f, cache.u0)
tape = ReverseDiff.compile(_tape)
grad = function (res, θ, args...)
ReverseDiff.gradient!(res, tape, θ)
end
else
cfg = ReverseDiff.GradientConfig(cache.u0)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
end
else
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
end
@@ -595,8 +646,24 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
if f.hess === nothing
hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
hess = function (res, θ, args...)
SparseDiffTools.forwarddiff_color_jacobian(res, _θ -> ReverseDiff.gradient(x -> _f(x, args...), _θ), θ, sparsity = hess_sparsity, colorvec = hess_colors)
if adtype.compile
T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0))
xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),length(cache.u0)}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), length(cache.u0))...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
function g(res1, θ)
ReverseDiff.gradient!(res1, htape, θ)
end
jaccfg = ForwardColorJacCache(g, cache.u0; tag = typeof(T), colorvec = hess_colors, sparsity = hess_sparsity)
hess = function (res, θ, args...)
SparseDiffTools.forwarddiff_color_jacobian!(res, g, θ, jaccfg)
end
else
hess = function (res, θ, args...)
res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
ReverseDiff.gradient(x -> _f(x, args...), θ)
end
end
end
else
hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
@@ -644,10 +711,28 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
cons_h = function (res, θ)
for i in 1:num_cons
res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
ReverseDiff.gradient(fncs[i], θ)
if adtype.compile
T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0))
xduals = [ForwardDiff.Dual{typeof(T),eltype(cache.u0),maximum(conshess_colors[i])}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), maximum(conshess_colors[i]))...,)))) for i in 1:num_cons]
consh_tapes = [ReverseDiff.GradientTape(fncs[i], xduals[i]) for i in 1:num_cons]
conshtapes = ReverseDiff.compile.(consh_tapes)
function grad_cons(res1, θ, htape)
ReverseDiff.gradient!(res1, htape, θ)
end
gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardColorJacCache(gs[i], cache.u0; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i])
end
end
else
cons_h = function (res, θ)
for i in 1:num_cons
res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
ReverseDiff.gradient(fncs[i], θ)
end
end
end
end
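As background on the Hessian branch added above: it records a ReverseDiff gradient tape on ForwardDiff.Dual inputs, compiles it, and then differentiates that gradient forward-mode, one color group at a time, with SparseDiffTools to assemble the sparse Hessian (forward-over-reverse). The following self-contained sketch shows the same pattern under simplifying assumptions; f, x, and SketchTag are placeholders, and the chunk handling is reduced to the small-dimension case where the cache's default chunk matches the tape's Dual chunk.

using ForwardDiff, ReverseDiff, SparseDiffTools, Symbolics, LinearAlgebra, SparseArrays

f(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2   # placeholder scalar objective
x = [1.0, 2.0]

# Hessian sparsity pattern and star coloring of its lower triangle, as in the extension above.
hess_sparsity = Symbolics.hessian_sparsity(f, x)
hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))

# Record and compile a gradient tape on Dual inputs so the compiled tape can be
# re-run at the Dual numbers that forwarddiff_color_jacobian! feeds in.
struct SketchTag end
T = ForwardDiff.Tag(SketchTag(), eltype(x))
xdual = ForwardDiff.Dual{typeof(T), eltype(x), length(x)}.(x,
    Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
htape = ReverseDiff.compile(ReverseDiff.GradientTape(f, xdual))

# In-place gradient through the compiled tape; differentiating it forward-mode,
# color group by color group, yields the (decompressed) Hessian.
g!(res, θ) = ReverseDiff.gradient!(res, htape, θ)
jaccfg = ForwardColorJacCache(g!, x; tag = typeof(T),
    colorvec = hess_colors, sparsity = hess_sparsity)

H = zeros(length(x), length(x))   # dense output, as in the new test below
SparseDiffTools.forwarddiff_color_jacobian!(H, g!, x, jaccfg)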
23 changes: 23 additions & 0 deletions test/ADtests.jl
@@ -521,6 +521,28 @@ sol = solve(prob, Optim.KrylovTrustRegion())
sol = solve(prob, Optimisers.ADAM(0.1), maxiters = 1000)
@test 10 * sol.objective < l1

optf = OptimizationFunction(rosenbrock, Optimization.AutoSparseReverseDiff(), cons = con2_c)
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoSparseReverseDiff(true),
nothing, 2)
G2 = Array{Float64}(undef, 2)
optprob.grad(G2, x0)
@test G1 ≈ G2 rtol=1e-4
H2 = Array{Float64}(undef, 2, 2)
optprob.hess(H2, x0)
@test H1 ≈ H2 rtol=1e-4
res = Array{Float64}(undef, 2)
optprob.cons(res, x0)
@test res ≈ [0.0, 0.0] atol=1e-4
optprob.cons(res, [1.0, 2.0])
@test res ≈ [5.0, 0.682941969615793]
J = Array{Float64}(undef, 2, 2)
optprob.cons_j(J, [5.0, 3.0])
@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3))
H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)]
optprob.cons_h(H3, x0)
@test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]


optf = OptimizationFunction(rosenbrock, Optimization.AutoSparseReverseDiff(), cons = con2_c)
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoSparseReverseDiff(),
nothing, 2)
@@ -542,6 +564,7 @@ H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)]
optprob.cons_h(H3, x0)
@test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]


optf = OptimizationFunction(rosenbrock, Optimization.AutoSparseReverseDiff())
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoSparseReverseDiff(),
nothing)
