Eliminated some runtime dispatch and other things
Vaibhavdixit02 committed Sep 25, 2023
1 parent 09cc1e7 commit 57c2f39
Showing 1 changed file with 51 additions and 24 deletions.
75 changes: 51 additions & 24 deletions ext/OptimizationSparseDiffExt.jl
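
Most of the dispatch elimination in this commit comes from how closures capture surrounding locals: a captured variable that Julia boxes (for example, because it is assigned more than once) is stored untyped, so calls through the closure can dispatch at runtime. The diff pins such captures either with keyword-argument defaults (`cons = cons, cache = jaccache.cache`) or with a `let` block (`gs = let conshtapes = conshtapes ... end`). The sketch below is a minimal, hypothetical illustration of that idiom — the `make_grad_*` names and the toy gradient are not from the repository:

```julia
# Boxed capture: `tape` is assigned twice, so lowering stores it in a
# Core.Box and calls through the closure can hit runtime dispatch.
function make_grad_boxed(n)
    tape = nothing
    tape = fill(1.0, n)
    return θ -> sum(tape .* θ)
end

# `let`-pinned capture, as done for `conshtapes` in this commit: the
# closure now holds a concretely typed field instead of a Box.
function make_grad_pinned(n)
    tape = nothing
    tape = fill(1.0, n)
    return let tape = tape
        θ -> sum(tape .* θ)
    end
end

# Keyword-argument-default capture, as done for `cons` and `jaccache.cache`:
# the captured object arrives as a typed argument on every call.
function make_grad_kw(n)
    tape = fill(1.0, n)
    return function (θ, args...; tape = tape)
        sum(tape .* θ)
    end
end

make_grad_boxed(3)(ones(3)) == make_grad_pinned(3)(ones(3)) == make_grad_kw(3)(ones(3)) == 3.0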
@@ -561,16 +561,18 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
     cons_jac_prototype = f.cons_jac_prototype
     cons_jac_colorvec = f.cons_jac_colorvec
     if cons !== nothing && f.cons_j === nothing
-        cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
-            zeros(eltype(x), num_cons),
-            x)
-        cons_jac_colorvec = matrix_colors(cons_jac_prototype)
-        jaccache = ForwardColorJacCache(cons, x;
-            colorvec = cons_jac_colorvec,
-            sparsity = cons_jac_prototype,
-            dx = zeros(eltype(x), num_cons))
-        cons_j = function (J, θ)
-            forwarddiff_color_jacobian!(J, cons, θ, jaccache)
+        jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, x, fx = zeros(eltype(x), num_cons))
+        # let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons
+        #     ForwardColorJacCache(cons, θ;
+        #         colorvec = cons_jac_colorvec,
+        #         sparsity = cons_jac_prototype,
+        #         dx = zeros(eltype(θ), num_cons))
+        # end
+        cons_jac_prototype = jaccache.jac_prototype
+        cons_jac_colorvec = jaccache.coloring
+        cons_j = function (J, θ, args...; cons = cons, cache = jaccache.cache)
+            forwarddiff_color_jacobian!(J, cons, θ, cache)
+            return
         end
     else
         cons_j = (J, θ) -> f.cons_j(J, θ, p)
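
The hunk above swaps the hand-rolled `Symbolics.jacobian_sparsity` + `ForwardColorJacCache` setup for `SparseDiffTools.sparse_jacobian_cache` with symbolic sparsity detection, then reads the prototype and coloring back off the cache. Below is a rough standalone sketch of that flow; it assumes SparseDiffTools v2 with Symbolics and ADTypes loaded, and `cons!`, `cons_oop`, and `x0` are made-up stand-ins for the constraint functions in this file, with the calls mirroring the ones in this diff:

```julia
using ADTypes, SparseDiffTools, Symbolics

# Made-up stand-ins for the in-place constraints and their out-of-place wrapper.
num_cons = 2
cons!(res, θ) = (res[1] = θ[1]^2 + θ[2]; res[2] = θ[3] - 1.0; nothing)
cons_oop(θ) = (res = zeros(eltype(θ), num_cons); cons!(res, θ); res)
x0 = ones(3)

# Build the cache once: symbolic sparsity detection plus a matrix coloring,
# following the added `sparse_jacobian_cache` call above.
jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(),
    SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, x0,
    fx = zeros(eltype(x0), num_cons))

J = Float64.(jaccache.jac_prototype)   # sparse prototype recovered from the cache
colors = jaccache.coloring             # coloring reused across evaluations

# Fill the Jacobian in place through the cached colored-ForwardDiff machinery,
# mirroring the new `cons_j` closure (which passes `jaccache.cache`).
forwarddiff_color_jacobian!(J, cons!, x0, jaccache.cache)
```

Building the cache once and reusing `jaccache.cache` inside the closure is what lets the new `cons_j` avoid re-detecting sparsity on every Jacobian call.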
@@ -592,7 +594,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
             end
             gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
             jaccfgs = [ForwardColorJacCache(gs[i], x; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
-            cons_h = function (res, θ)
+            cons_h = function (res, θ, args...)
                 for i in 1:num_cons
                     SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i])
                 end
@@ -692,23 +694,32 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     if f.cons === nothing
         cons = nothing
     else
-        cons = (res, θ) -> f.cons(res, θ, cache.p)
+        cons = function (res, θ)
+            f.cons(res, θ, cache.p)
+            return
+        end
+        cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
     end

     cons_jac_prototype = f.cons_jac_prototype
     cons_jac_colorvec = f.cons_jac_colorvec
     if cons !== nothing && f.cons_j === nothing
-        cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
-            zeros(eltype(cache.u0), num_cons),
-            cache.u0)
-        cons_jac_colorvec = matrix_colors(cons_jac_prototype)
-        jaccache = ForwardColorJacCache(cons, cache.u0;
-            colorvec = cons_jac_colorvec,
-            sparsity = cons_jac_prototype,
-            dx = zeros(eltype(cache.u0), num_cons))
+        # cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
+        #     zeros(eltype(cache.u0), num_cons),
+        #     cache.u0)
+        # cons_jac_colorvec = matrix_colors(cons_jac_prototype)
+        jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, cache.u0, fx = zeros(eltype(cache.u0), num_cons))
+        # let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons
+        #     ForwardColorJacCache(cons, θ;
+        #         colorvec = cons_jac_colorvec,
+        #         sparsity = cons_jac_prototype,
+        #         dx = zeros(eltype(θ), num_cons))
+        # end
+        cons_jac_prototype = jaccache.jac_prototype
+        cons_jac_colorvec = jaccache.coloring
         cons_j = function (J, θ)
-            forwarddiff_color_jacobian!(J, cons, θ, jaccache)
+            forwarddiff_color_jacobian!(J, cons, θ, jaccache.cache)
             return
         end
     else
         cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
@@ -717,8 +728,18 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     conshess_sparsity = f.cons_hess_prototype
     conshess_colors = f.cons_hess_colorvec
     if cons !== nothing && f.cons_h === nothing
-        fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
-        conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
+        fncs = map(1:num_cons) do i
+            function (x)
+                res = zeros(eltype(x), num_cons)
+                f.cons(res, x, cache.p)
+                return res[i]
+            end
+        end
+        conshess_sparsity = map(1:num_cons) do i
+            let fnc = fncs[i], θ = cache.u0
+                Symbolics.hessian_sparsity(fnc, θ)
+            end
+        end
         conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
         if adtype.compile
             T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0))
@@ -728,7 +749,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
             function grad_cons(res1, θ, htape)
                 ReverseDiff.gradient!(res1, htape, θ)
            end
-            gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
+            gs = let conshtapes = conshtapes
+                map(1:num_cons) do i
+                    function (res1, x)
+                        grad_cons(res1, x, conshtapes[i])
+                    end
+                end
+            end
             jaccfgs = [ForwardColorJacCache(gs[i], cache.u0; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
             cons_h = function (res, θ)
                 for i in 1:num_cons
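
For the constraint Hessians, the last two hunks replace the `cons_oop(x)[i]` comprehension with per-constraint closures built directly over `f.cons`, then take each component's Hessian sparsity pattern and coloring. A self-contained sketch of that flow, where `my_cons!`, `p`, and `θ0` are illustrative stand-ins rather than anything from Optimization.jl:

```julia
using Symbolics, SparseDiffTools

# Illustrative in-place constraint function in the style of f.cons(res, θ, p).
function my_cons!(res, θ, p)
    res[1] = θ[1]^2 + p[1] * θ[2]
    res[2] = θ[2] * θ[3]
    return
end

p = [2.0]
num_cons = 2
θ0 = ones(3)

# One scalar-valued closure per constraint, as in the new `fncs = map(...)` block.
fncs = map(1:num_cons) do i
    function (x)
        res = zeros(eltype(x), num_cons)
        my_cons!(res, x, p)
        return res[i]
    end
end

# Per-constraint Hessian sparsity patterns and colorings, as the diff computes
# them for the sparse constraint Hessians.
conshess_sparsity = map(1:num_cons) do i
    Symbolics.hessian_sparsity(fncs[i], θ0)
end
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
```

Each closure returns one scalar component, which is what lets `Symbolics.hessian_sparsity` trace a separate Hessian pattern per constraint.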
