Skip to content

Commit

Permalink
Faster Nested AD
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 24, 2024
1 parent 29e971e commit dba4b21
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksZygoteExt = "Zygote"
DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"]

[compat]
ADTypes = "0.2.5, 1"
Expand All @@ -38,6 +39,7 @@ ConstructionBase = "1"
DiffEqBase = "6.119"
ExplicitImports = "1.4.1"
FastClosures = "0.3"
ForwardDiff = "0.10.36"
Functors = "0.4.10"
LinearSolve = "2.21.2"
Lux = "0.5.37"
Expand Down
49 changes: 44 additions & 5 deletions ext/DeepEquilibriumNetworksZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,58 @@
module DeepEquilibriumNetworksZygoteExt

using ADTypes: AutoZygote
using ChainRulesCore: ChainRulesCore
using DeepEquilibriumNetworks: DEQs
using FastClosures: @closure
using ForwardDiff: ForwardDiff # This is a dependency of Zygote
using Lux: Lux, StatefulLuxLayer
using Statistics: mean
using Zygote: Zygote
using DeepEquilibriumNetworks: DEQs

@inline __tupleify(u) = @closure x -> (u, x)
const CRC = ChainRulesCore

@inline __tupleify(x) = @closure(u->(u, x))

## One day we will overload DI's APIs for Lux Layers and we can remove this
## Main challenge with overloading Zygote.pullback is that we need to return the correct
## tangent for the pullback to compute the correct gradient, which is quite hard. But
## wrapping the overall vjp is not that hard.
@inline function __compute_vector_jacobian_product(model::StatefulLuxLayer, ps, z, x, rng)
res, back = Zygote.pullback(model __tupleify(x), z)
return only(back(DEQs.__gaussian_like(rng, res)))
end

function CRC.rrule(
::typeof(__compute_vector_jacobian_product), model::StatefulLuxLayer, ps, z, x, rng)
res, back = Zygote.pullback(model __tupleify(x), z)
ε = DEQs.__gaussian_like(rng, res)
y = only(back(ε))
∇internal_gradient_capture = Δ -> begin
isa CRC.NoTangent || Δ isa CRC.ZeroTangent) &&
return ntuple(Returns(CRC.NoTangent()), 6)

Δ_ = reshape(CRC.unthunk(Δ), size(z))

Tag = typeof(ForwardDiff.Tag(model, eltype(z)))
partials = ForwardDiff.Partials{1, eltype(z)}.(tuple.(Δ_))
z_dual = ForwardDiff.Dual{Tag, eltype(z), 1}.(z, partials)

_, pb_f = Zygote.pullback((x1, x2, p) -> model((x1, x2), p), z_dual, x, ps)
∂z_duals, ∂x_duals, ∂ps_duals = pb_f(ε)

∂z = Lux.__partials(Tag, ∂z_duals, 1)
∂x = Lux.__partials(Tag, ∂x_duals, 1)
∂ps = Lux.__partials(Tag, ∂ps_duals, 1)

return CRC.NoTangent(), CRC.NoTangent(), ∂ps, ∂z, ∂x, CRC.NoTangent()
end
return y, ∇internal_gradient_capture
end

## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
## FIXME: This will be broken in the new Lux release let's fix this
function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
res, back = Zygote.pullback(model __tupleify, z)
vjp_z = only(back(DEQs.__gaussian_like(rng, res)))
return mean(abs2, vjp_z)
return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng))
end

end

0 comments on commit dba4b21

Please sign in to comment.