From 57c2f39d0aae6d542c511d129e22f7f6ab7edc27 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 25 Sep 2023 18:38:48 -0400 Subject: [PATCH] Eliminated some runtime dispatch and other things --- ext/OptimizationSparseDiffExt.jl | 75 ++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 24 deletions(-) diff --git a/ext/OptimizationSparseDiffExt.jl b/ext/OptimizationSparseDiffExt.jl index c1990591f..a115b2710 100644 --- a/ext/OptimizationSparseDiffExt.jl +++ b/ext/OptimizationSparseDiffExt.jl @@ -561,16 +561,18 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff, cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && f.cons_j === nothing - cons_jac_prototype = Symbolics.jacobian_sparsity(cons, - zeros(eltype(x), num_cons), - x) - cons_jac_colorvec = matrix_colors(cons_jac_prototype) - jaccache = ForwardColorJacCache(cons, x; - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype, - dx = zeros(eltype(x), num_cons)) - cons_j = function (J, θ) - forwarddiff_color_jacobian!(J, cons, θ, jaccache) + jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, x, fx = zeros(eltype(x), num_cons)) + # let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons + # ForwardColorJacCache(cons, θ; + # colorvec = cons_jac_colorvec, + # sparsity = cons_jac_prototype, + # dx = zeros(eltype(θ), num_cons)) + # end + cons_jac_prototype = jaccache.jac_prototype + cons_jac_colorvec = jaccache.coloring + cons_j = function (J, θ, args...;cons = cons, cache = jaccache.cache) + forwarddiff_color_jacobian!(J, cons, θ, cache) + return end else cons_j = (J, θ) -> f.cons_j(J, θ, p) @@ -592,7 +594,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff, end gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons] jaccfgs = [ForwardColorJacCache(gs[i], x; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons] - cons_h = function (res, θ) + cons_h = function (res, θ, args...) for i in 1:num_cons SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i]) end @@ -692,23 +694,32 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, if f.cons === nothing cons = nothing else - cons = (res, θ) -> f.cons(res, θ, cache.p) + cons = function (res, θ) + f.cons(res, θ, cache.p) + return + end cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) end cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && f.cons_j === nothing - cons_jac_prototype = Symbolics.jacobian_sparsity(cons, - zeros(eltype(cache.u0), num_cons), - cache.u0) - cons_jac_colorvec = matrix_colors(cons_jac_prototype) - jaccache = ForwardColorJacCache(cons, cache.u0; - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype, - dx = zeros(eltype(cache.u0), num_cons)) + # cons_jac_prototype = Symbolics.jacobian_sparsity(cons, + # zeros(eltype(cache.u0), num_cons), + # cache.u0) + # cons_jac_colorvec = matrix_colors(cons_jac_prototype) + jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, cache.u0, fx = zeros(eltype(cache.u0), num_cons)) + # let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons + # ForwardColorJacCache(cons, θ; + # colorvec = cons_jac_colorvec, + # sparsity = cons_jac_prototype, + # dx = zeros(eltype(θ), num_cons)) + # end + cons_jac_prototype = jaccache.jac_prototype + cons_jac_colorvec = jaccache.coloring cons_j = function (J, θ) - forwarddiff_color_jacobian!(J, cons, θ, jaccache) + forwarddiff_color_jacobian!(J, cons, θ, jaccache.cache) + return end else cons_j = (J, θ) -> f.cons_j(J, θ, cache.p) @@ -717,8 +728,18 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0)) + fncs = map(1:num_cons) do i + function (x) + res = zeros(eltype(x), num_cons) + f.cons(res, x, cache.p) + return res[i] + end + end + conshess_sparsity = map(1:num_cons) do i + let fnc = fncs[i], θ = cache.u0 + Symbolics.hessian_sparsity(fnc, θ) + end + end conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity) if adtype.compile T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0)) @@ -728,7 +749,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, function grad_cons(res1, θ, htape) ReverseDiff.gradient!(res1, htape, θ) end - gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons] + gs = let conshtapes = conshtapes + map(1:num_cons) do i + function (res1, x) + grad_cons(res1, x, conshtapes[i]) + end + end + end jaccfgs = [ForwardColorJacCache(gs[i], cache.u0; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons] cons_h = function (res, θ) for i in 1:num_cons