Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate some runtime dispatch and other things #597

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 51 additions & 24 deletions ext/OptimizationSparseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,16 +561,18 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
cons_jac_prototype = f.cons_jac_prototype
cons_jac_colorvec = f.cons_jac_colorvec
if cons !== nothing && f.cons_j === nothing
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
zeros(eltype(x), num_cons),
x)
cons_jac_colorvec = matrix_colors(cons_jac_prototype)
jaccache = ForwardColorJacCache(cons, x;
colorvec = cons_jac_colorvec,
sparsity = cons_jac_prototype,
dx = zeros(eltype(x), num_cons))
cons_j = function (J, θ)
forwarddiff_color_jacobian!(J, cons, θ, jaccache)
jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, x, fx = zeros(eltype(x), num_cons))
# let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons
# ForwardColorJacCache(cons, θ;
# colorvec = cons_jac_colorvec,
# sparsity = cons_jac_prototype,
# dx = zeros(eltype(θ), num_cons))
# end
cons_jac_prototype = jaccache.jac_prototype
cons_jac_colorvec = jaccache.coloring
cons_j = function (J, θ, args...;cons = cons, cache = jaccache.cache)
forwarddiff_color_jacobian!(J, cons, θ, cache)
return
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, p)
Expand All @@ -592,7 +594,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
end
gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardColorJacCache(gs[i], x; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
cons_h = function (res, θ)
cons_h = function (res, θ, args...)
for i in 1:num_cons
SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i])
end
Expand Down Expand Up @@ -692,23 +694,32 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
if f.cons === nothing
cons = nothing
else
cons = (res, θ) -> f.cons(res, θ, cache.p)
cons = function (res, θ)
f.cons(res, θ, cache.p)
return
end
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
end

cons_jac_prototype = f.cons_jac_prototype
cons_jac_colorvec = f.cons_jac_colorvec
if cons !== nothing && f.cons_j === nothing
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
zeros(eltype(cache.u0), num_cons),
cache.u0)
cons_jac_colorvec = matrix_colors(cons_jac_prototype)
jaccache = ForwardColorJacCache(cons, cache.u0;
colorvec = cons_jac_colorvec,
sparsity = cons_jac_prototype,
dx = zeros(eltype(cache.u0), num_cons))
# cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
# zeros(eltype(cache.u0), num_cons),
# cache.u0)
# cons_jac_colorvec = matrix_colors(cons_jac_prototype)
jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, cache.u0, fx = zeros(eltype(cache.u0), num_cons))
# let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons
# ForwardColorJacCache(cons, θ;
# colorvec = cons_jac_colorvec,
# sparsity = cons_jac_prototype,
# dx = zeros(eltype(θ), num_cons))
# end
cons_jac_prototype = jaccache.jac_prototype
cons_jac_colorvec = jaccache.coloring
cons_j = function (J, θ)
forwarddiff_color_jacobian!(J, cons, θ, jaccache)
forwarddiff_color_jacobian!(J, cons, θ, jaccache.cache)
return
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
Expand All @@ -717,8 +728,18 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
conshess_sparsity = f.cons_hess_prototype
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
fncs = map(1:num_cons) do i
function (x)
res = zeros(eltype(x), num_cons)
f.cons(res, x, cache.p)
return res[i]
end
end
conshess_sparsity = map(1:num_cons) do i
let fnc = fncs[i], θ = cache.u0
Symbolics.hessian_sparsity(fnc, θ)
end
end
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
if adtype.compile
T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0))
Expand All @@ -728,7 +749,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
function grad_cons(res1, θ, htape)
ReverseDiff.gradient!(res1, htape, θ)
end
gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
gs = let conshtapes = conshtapes
map(1:num_cons) do i
function (res1, x)
grad_cons(res1, x, conshtapes[i])
end
end
end
jaccfgs = [ForwardColorJacCache(gs[i], cache.u0; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
Expand Down