Merge pull request #590 from SciML/reversediff_tape_compilaton
Setup tape compilation for ReverseDiff
Vaibhavdixit02 authored Sep 19, 2023
2 parents c3f0da8 + 0f6c2f1 commit 1167183
Showing 7 changed files with 346 additions and 63 deletions.
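For context, a minimal usage sketch (illustrative, not part of this commit) of how the compiled-tape path added below is switched on from the user side, assuming the compile flag of ADTypes' AutoReverseDiff is the intended entry point:

using Optimization, OptimizationOptimJL, ReverseDiff

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
x0 = zeros(2)
p = [1.0, 100.0]

# compile = true asks the extension to record and compile ReverseDiff tapes for
# the gradient (and a ForwardDiff-over-ReverseDiff tape for the Hessian); this
# is only valid when the objective has no value-dependent control flow.
optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(compile = true))
prob = OptimizationProblem(optf, x0, p)
sol = solve(prob, Newton())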
22 changes: 12 additions & 10 deletions ext/OptimizationEnzymeExt.jl
@@ -6,7 +6,7 @@ import Optimization.LinearAlgebra: I
import Optimization.ADTypes: AutoEnzyme
isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme)

@inline function firstapply(f, θ, p, args...)
@inline function firstapply(f::F, θ, p, args...) where F
res = f(θ, p, args...)
if isa(res, AbstractFloat)
res
@@ -20,15 +20,17 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
num_cons = 0)

if f.grad === nothing
function grad(res, θ, args...)
res .= zero(eltype(res))
Enzyme.autodiff(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f.f),
Enzyme.Duplicated(θ, res),
Const(p),
args...)
grad = let
function (res, θ, args...)
res .= zero(eltype(res))
Enzyme.autodiff(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f.f),
Enzyme.Duplicated(θ, res),
Const(p),
args...)
end
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
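A note on the OptimizationEnzymeExt hunk above: the f::F ... where F annotation on firstapply makes the concrete function type a static parameter of the method, a standard Julia pattern for guaranteeing specialization on function-valued arguments, and the gradient closure is now built inside a let block. A tiny sketch of the signature pattern, with illustrative names only:

# Julia's heuristics may skip specializing on a bare function-typed argument
# that is merely passed through to another call; the where-F form forces one
# compiled method instance per concrete function type.
passthrough(f, op, x) = op(f, x)
passthrough_forced(f::F, op, x) where {F} = op(f, x)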
157 changes: 130 additions & 27 deletions ext/OptimizationReverseDiffExt.jl
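Before the hunks, a reminder of the core ReverseDiff pattern that the adtype.compile branches below build on: record a gradient tape once at a representative input, compile it, and reuse it for every later gradient call. This is only valid when the function's control flow does not depend on the input values. A minimal, self-contained sketch with illustrative names:

using ReverseDiff

f(θ) = sum(θ .^ 2)                        # illustrative objective
x = rand(3)

_tape = ReverseDiff.GradientTape(f, x)    # record once
tape = ReverseDiff.compile(_tape)         # compile the recorded tape
g = similar(x)
ReverseDiff.gradient!(g, tape, x)         # repeated calls reuse the compiled tape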
@@ -7,32 +7,60 @@ import Optimization.ADTypes: AutoReverseDiff
isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
(using ..ReverseDiff, ..ReverseDiff.ForwardDiff)

struct OptimizationReverseDiffTag end

function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
p = SciMLBase.NullParameters(),
num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))

if f.grad === nothing
cfg = ReverseDiff.GradientConfig(x)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
if adtype.compile
_tape = ReverseDiff.GradientTape(_f, x)
tape = ReverseDiff.compile(_tape)
grad = function (res, θ, args...)
ReverseDiff.gradient!(res, tape, θ)
end
else
cfg = ReverseDiff.GradientConfig(x)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
end

if f.hess === nothing
hess = function (res, θ, args...)
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
if adtype.compile
T = ForwardDiff.Tag(OptimizationReverseDiffTag(),eltype(x))
xdual = ForwardDiff.Dual{typeof(T),eltype(x),length(x)}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
function g(θ)
res1 = zeros(eltype(θ), length(θ))
ReverseDiff.gradient!(res1, htape, θ)
end
jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk(x), T)
hess = function (res, θ, args...)
ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
end
else
hess = function (res, θ, args...)
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
end
end
else
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
end

if f.hv === nothing
hv = function (H, θ, v, args...)
_θ = ForwardDiff.Dual.(θ, v)
res = similar(_θ)
grad(res, _θ, args...)
H .= getindex.(ForwardDiff.partials.(res), 1)
# _θ = ForwardDiff.Dual.(θ, v)
# res = similar(_θ)
# grad(res, _θ, args...)
# H .= getindex.(ForwardDiff.partials.(res), 1)
res = zeros(length(θ), length(θ))
hess(res, θ, args...)
H .= res * v
end
else
hv = f.hv
@@ -46,19 +74,43 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
end

if cons !== nothing && f.cons_j === nothing
cjconfig = ReverseDiff.JacobianConfig(x)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
if adtype.compile
_jac_tape = ReverseDiff.JacobianTape(cons_oop, x)
jac_tape = ReverseDiff.compile(_jac_tape)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, jac_tape, θ)
end
else
cjconfig = ReverseDiff.JacobianConfig(x)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
end
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, p)
end

if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ReverseDiff.hessian!(res[i], fncs[i], θ)
if adtype.compile
consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual))
conshtapes = ReverseDiff.compile.(consh_tapes)
function grad_cons(θ, htape)
res1 = zeros(eltype(θ), length(θ))
ReverseDiff.gradient!(res1, htape, θ)
end
gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardDiff.JacobianConfig(gs[i], x, ForwardDiff.Chunk(x), T) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
end
end
else
cons_h = function (res, θ)
for i in 1:num_cons
ReverseDiff.hessian!(res[i], fncs[i], θ)
end
end
end
else
@@ -83,25 +135,52 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
_f = (θ, args...) -> first(f.f(θ, cache.p, args...))

if f.grad === nothing
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
if adtype.compile
_tape = ReverseDiff.GradientTape(_f, cache.u0)
tape = ReverseDiff.compile(_tape)
grad = function (res, θ, args...)
ReverseDiff.gradient!(res, tape, θ)
end
else
cfg = ReverseDiff.GradientConfig(cache.u0)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
end
else
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
end

if f.hess === nothing
hess = function (res, θ, args...)
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
if adtype.compile
T = ForwardDiff.Tag(OptimizationReverseDiffTag(),eltype(cache.u0))
xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),length(cache.u0)}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), length(cache.u0))...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
function g(θ)
res1 = zeros(eltype(θ), length(θ))
ReverseDiff.gradient!(res1, htape, θ)
end
jaccfg = ForwardDiff.JacobianConfig(g, cache.u0, ForwardDiff.Chunk(cache.u0), T)
hess = function (res, θ, args...)
ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
end
else
hess = function (res, θ, args...)
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
end
end
else
hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
end

if f.hv === nothing
hv = function (H, θ, v, args...)
_θ = ForwardDiff.Dual.(θ, v)
res = similar(_θ)
grad(res, _θ, args...)
H .= getindex.(ForwardDiff.partials.(res), 1)
# _θ = ForwardDiff.Dual.(θ, v)
# res = similar(_θ)
# grad(res, _θ, args...)
# H .= getindex.(ForwardDiff.partials.(res), 1)
res = zeros(length(θ), length(θ))
hess(res, θ, args...)
H .= res * v
end
else
hv = f.hv
@@ -115,19 +194,43 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
end

if cons !== nothing && f.cons_j === nothing
cjconfig = ReverseDiff.JacobianConfig(cache.u0)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
if adtype.compile
_jac_tape = ReverseDiff.JacobianTape(cons_oop, cache.u0)
jac_tape = ReverseDiff.compile(_jac_tape)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, jac_tape, θ)
end
else
cjconfig = ReverseDiff.JacobianConfig(cache.u0)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
end
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
end

if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ReverseDiff.hessian!(res[i], fncs[i], θ)
if adtype.compile
consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual))
conshtapes = ReverseDiff.compile.(consh_tapes)
function grad_cons(θ, htape)
res1 = zeros(eltype(θ), length(θ))
ReverseDiff.gradient!(res1, htape, θ)
end
gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardDiff.JacobianConfig(gs[i], cache.u0, ForwardDiff.Chunk(cache.u0), T) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
end
end
else
cons_h = function (res, θ)
for i in 1:num_cons
ReverseDiff.hessian!(res[i], fncs[i], θ)
end
end
end
else
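The compiled Hessian path above is forward-over-reverse: a ReverseDiff gradient tape is recorded and compiled at Dual-valued inputs, and ForwardDiff then differentiates that taped gradient, using an explicit tag with tag checking disabled. A condensed, self-contained sketch of the same construction (objective and tag name are illustrative, not from the package):

using ReverseDiff, ForwardDiff

f(θ) = (1 - θ[1])^2 + 100 * (θ[2] - θ[1]^2)^2   # illustrative objective
x = [0.5, 0.5]

struct SketchTag end                             # stands in for OptimizationReverseDiffTag
T = ForwardDiff.Tag(SketchTag(), eltype(x))
xdual = ForwardDiff.Dual{typeof(T), eltype(x), length(x)}.(x,
    Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))

# Reverse mode: compile a gradient tape whose tracked values carry Dual numbers.
htape = ReverseDiff.compile(ReverseDiff.GradientTape(f, xdual))
g(θ) = ReverseDiff.gradient!(zeros(eltype(θ), length(θ)), htape, θ)

# Forward mode over the taped gradient; pass the custom tag explicitly and skip
# ForwardDiff's tag check, as the hunks above do.
cfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk(x), T)
H = ForwardDiff.jacobian(g, x, cfg, Val{false}())  # 2×2 Hessian of f at x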