Capture ForwardDiff.gradient calls
avik-pal committed Apr 24, 2024
1 parent bffeffb commit 131fa74
Showing 10 changed files with 279 additions and 35 deletions.
3 changes: 3 additions & 0 deletions bench/Project.toml
@@ -18,3 +18,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.2"
5 changes: 5 additions & 0 deletions docs/Project.toml
@@ -2,11 +2,15 @@
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
@@ -19,6 +23,7 @@ Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -50,7 +50,8 @@ pages = [
"manual/gpu_management.md",
"manual/migrate_from_flux.md",
"manual/weight_initializers.md",
"manual/distributed_utils.md"
"manual/distributed_utils.md",
"manual/nested_autodiff.md"
],
"API Reference" => [
"Lux" => [
3 changes: 2 additions & 1 deletion docs/src/.vitepress/config.mts
@@ -168,7 +168,8 @@ export default defineConfig({
{ text: 'GPU Management', link: '/manual/gpu_management' },
{ text: 'Migrating from Flux to Lux', link: '/manual/migrate_from_flux' },
{ text: 'Initializing Weights', link: '/manual/weight_initializers' },
{ text: 'Distributed Data Parallel Training', link: '/manual/distributed_utils' },]
{ text: 'Distributed Data Parallel Training', link: '/manual/distributed_utils' },
{ text: 'Nested Automatic Differentiation', link: '/manual/nested_autodiff' },]
},
"/api/": {
text: 'API Reference', collapsed: false, items: [
18 changes: 18 additions & 0 deletions docs/src/api/Building_Blocks/LuxLib.md
@@ -12,6 +12,18 @@ CurrentModule = LuxLib
Pages = ["LuxLib.md"]
```

## Fully Connected Layers

```@docs
fused_dense_bias_activation
```

## Convolutional Layers

```@docs
fused_conv_bias_activation
```

## Dropout

```@docs
@@ -27,3 +39,9 @@ groupnorm
instancenorm
layernorm
```

## Apply Activation

```@docs
fast_activation!!
```
130 changes: 130 additions & 0 deletions docs/src/manual/nested_autodiff.md
@@ -0,0 +1,130 @@
# Nested Automatic Differentiation

!!! note

    This is a relatively new feature in Lux, so there might be some rough edges. If you
    encounter any issues, please let us know by opening an issue on the
    [GitHub repository](https://github.com/LuxDL/Lux.jl).

In this manual, we will explore how to use automatic differentiation (AD) inside your layers
or loss functions and have Lux automatically switch the AD backend to a faster one when
needed.
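
Roughly speaking, when Lux detects one of the supported nested AD calls inside an outer
reverse-mode (Zygote) differentiation, it computes the derivatives of that inner call with
forward-over-reverse AD instead of nesting reverse mode twice. The idea, sketched here only
for a scalar function ``f``, is that the directional derivative of a gradient is a
Hessian-vector product, which a single forward-mode pass over the reverse-mode gradient
computes cheaply:

```math
\left. \frac{\partial}{\partial \epsilon} \nabla_x f(x + \epsilon v) \right|_{\epsilon = 0}
    = \nabla^2 f(x) \, v
```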

!!! tip

    Don't want Lux to do this switching for you? You can disable it by setting the
    `DisableAutomaticNestedADSwitching` Preference to `true`.
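
For reference, here is one way to set that preference. This is only a sketch: it assumes
Lux reads the flag through [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl)
at package load time, so a fresh Julia session is needed for the change to take effect.

```julia
# Sketch under the assumption that Lux reads `DisableAutomaticNestedADSwitching`
# via Preferences.jl at load time; restart Julia after changing it.
using Lux, Preferences

set_preferences!(Lux, "DisableAutomaticNestedADSwitching" => true; force=true)
```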

Remember that if you use ForwardDiff inside a Zygote call without this switching, the
gradients with respect to the parameters will be dropped (with a warning message), so this
combination is not recommended.

Let's explore this using some questions that were posted on the
[Julia Discourse forum](https://discourse.julialang.org/).

```@example nested_ad
using Lux, LinearAlgebra, Zygote, ForwardDiff, Random
using ComponentArrays, FiniteDiff
```

First, let's set the stage by going over the minor requirements for this feature to
work:

1. Switching only works if a [`StatefulLuxLayer`](@ref) is being used, with the following
   function calls:
    - `(<some-function> ∘ <StatefulLuxLayer>)(x::AbstractArray)`
    - `(<StatefulLuxLayer> ∘ <some-function>)(x::AbstractArray)`
    - `(<StatefulLuxLayer>)(x::AbstractArray)`
2. Currently we have custom routines implemented for:
    - `Zygote.<gradient|jacobian>`
    - `ForwardDiff.<gradient|jacobian>`
3. Switching only happens for `ChainRules`-compatible AD libraries.

We plan to capture `DifferentiationInterface`, `Zygote.pullback`, and `Enzyme.autodiff`
calls in the future (PRs are welcome).

## Nested AD for Neural Differential Equations (DEs)

This problem comes from `@facusapienza` on [Discourse](https://discourse.julialang.org/t/nested-and-different-ad-methods-altogether-how-to-add-ad-calculations-inside-my-loss-function-when-using-neural-differential-equations/108985).
In this case, we want to add a regularization term to the neural DE based on first-order
derivatives. The neural DE part is not important here and we can demonstrate this easily
with a standard neural network.
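
Written out, the loss implemented below is the empirical squared error plus a squared
Frobenius-norm penalty on the input-output Jacobian of the network ``\hat{y} = f_\theta(x)``:

```math
\mathcal{L}(\theta) = \|\hat{y} - y\|_F^2
    + \left\| \frac{\partial \hat{y}}{\partial x} \right\|_F^2
```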

```@example nested_ad
function loss_function1(model, x, ps, st, y)
    # Make it a stateful layer
    smodel = StatefulLuxLayer(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    # You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here
    J = ForwardDiff.jacobian(smodel, x)
    loss_reg = abs2(norm(J))
    return loss_emp + loss_reg
end

# Using BatchNorm to show that it is possible
model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Xoshiro(0), Float32, 2, 10)
y = rand(Xoshiro(11), Float32, 2, 10)

loss_function1(model, x, ps, st, y)
```

So our loss function works; let's take the gradient (forward-mode AD doesn't nest nicely here):

```@example nested_ad
_, ∂x, ∂ps, _, _ = Zygote.gradient(loss_function1, model, x, ps, st, y)
```

Now let's verify the gradients using finite differences:

```@example nested_ad
∂x_fd = FiniteDiff.finite_difference_gradient(x -> loss_function1(model, x, ps, st, y), x)
∂ps_fd = FiniteDiff.finite_difference_gradient(ps -> loss_function1(model, x, ps, st, y),
    ComponentArray(ps))
println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
nothing; # hide
```

That's pretty good; of course, you will have some error from the finite-differences
calculation.

## Loss Function Containing a Gradient Calculation

Okay, here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs
on [Discourse](https://discourse.julialang.org/t/is-it-possible-to-do-nested-ad-elegantly-in-julia-pinns/98888/21).
As was the consensus there, we shouldn't use nested AD for third- or higher-order
differentiation. Note that in the example there the user uses `ForwardDiff.derivative`, but
we will use `ForwardDiff.gradient` instead, since we typically deal with array inputs and
outputs.
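
In symbols, writing the network below as ``f_\theta``, the loss we are about to implement is

```math
\mathcal{L}(\theta) = \sum_j \left( \frac{\partial}{\partial t_j}
    \sum_k f_\theta(t)_k^2 - \cos(t_j) \right)^2,
```

i.e., we fit the gradient of a scalar function of the network output to `cos.(t)`.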

```@example nested_ad
function loss_function2(model, t, ps, st)
    smodel = StatefulLuxLayer(model, ps, st)
    ŷ = only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, t)) # Zygote returns a tuple
    return sum(abs2, ŷ .- cos.(t))
end

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(Xoshiro(0), model)
t = rand(Xoshiro(0), Float32, 1, 16)
```

Now the moment of truth:

```@example nested_ad
_, ∂t, ∂ps, _ = Zygote.gradient(loss_function2, model, t, ps, st)
```

Boom, that worked! Let's verify the gradient using ForwardDiff:

```@example nested_ad
∂t_fd = ForwardDiff.gradient(t -> loss_function2(model, t, ps, st), t)
∂ps_fd = ForwardDiff.gradient(ps -> loss_function2(model, t, ps, st), ComponentArray(ps))
println("∞-norm(∂t - ∂t_fd): ", norm(∂t .- ∂t_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
nothing; # hide
```
115 changes: 90 additions & 25 deletions ext/LuxForwardDiffExt.jl
@@ -10,16 +10,16 @@ const CRC = ChainRulesCore

@inline Lux._is_extension_loaded(::Val{:ForwardDiff}) = true

@inline __partials(::Type{Tag}, x::AbstractArray, i) where {Tag} = ForwardDiff.partials.(
@inline Lux.__partials(::Type{Tag}, x::AbstractArray, i) where {Tag} = ForwardDiff.partials.(
Tag, x, i)
@inline __partials(::Type{Tag}, x::Tuple, i) where {Tag} = map(
@closure(xᵢ->__partials(Tag, xᵢ, i)), x)
@inline __partials(::Type{Tag}, x::NamedTuple{F}, i) where {Tag, F} = NamedTuple{F}(map(
@closure(xᵢ->__partials(Tag, xᵢ, i)), values(x)))
@inline __partials(::Type{Tag}, x::CRC.AbstractTangent, i) where {Tag} = __partials(
@inline Lux.__partials(::Type{Tag}, x::Tuple, i) where {Tag} = map(
@closure(xᵢ->Lux.__partials(Tag, xᵢ, i)), x)
@inline Lux.__partials(::Type{Tag}, x::NamedTuple{F}, i) where {Tag, F} = NamedTuple{F}(map(
@closure(xᵢ->Lux.__partials(Tag, xᵢ, i)), values(x)))
@inline Lux.__partials(::Type{Tag}, x::CRC.AbstractTangent, i) where {Tag} = Lux.__partials(
Tag, CRC.backing(x), i)
@inline __partials(::Type{Tag}, x, i) where {Tag} = fmap(
@closure(xᵢ->__partials(Tag, xᵢ, i)), x)
@inline Lux.__partials(::Type{Tag}, x, i) where {Tag} = fmap(
@closure(xᵢ->Lux.__partials(Tag, xᵢ, i)), x)

# This is not a general jvp code, but rather meant to be efficient for nested AD calls
function Lux.__forwarddiff_jvp(
@@ -29,56 +29,122 @@ function Lux.__forwarddiff_jvp(
partials = ForwardDiff.Partials{1, T}.(tuple.(Δx))
x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
y_dual, ps_dual = f(x_dual, ps)
return __partials(Tag, y_dual, 1), __partials(Tag, ps_dual, 1)
return Lux.__partials(Tag, y_dual, 1), Lux.__partials(Tag, ps_dual, 1)
end

# Capture ForwardDiff.jacobian call and replace it with forward over reverse mode AD
@inline function __updated_jacobian_config(::ForwardDiff.JacobianConfig{T, V, N, D}, f::F,
@inline function __updated_forwarddiff_config(
::ForwardDiff.JacobianConfig{T, V, N, D}, f::F,
x::AbstractArray{V}) where {T, V, N, D, F}
return ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk{N}())
end

@inline function __updated_forwarddiff_config(
::ForwardDiff.GradientConfig{T, V, N, D}, f::F,
x::AbstractArray{V}) where {T, V, N, D, F}
return ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{N}())
end

# TODO: We can define multiple dispatches using meta programming to not construct these
# intermediate configs, but that is kind of a micro-optimization, so we can handle
# those later.
@inline function Lux.__internal_forwarddiff_jacobian_capture(
@inline function __internal_gradient_capture(
f::F, cfg::ForwardDiff.GradientConfig, chk::Val, x, args...) where {F}
# Here we can't really pass in the actual config because we modify the internal function
__f = @closure(x->f(x, args...))
return ForwardDiff.gradient(__f, x, __updated_forwarddiff_config(cfg, __f, x), chk)
end

@inline function ForwardDiff.gradient(
f::Base.ComposedFunction{<:Lux.StatefulLuxLayer, F}, x::AbstractArray,
cfg::ForwardDiff.GradientConfig=ForwardDiff.GradientConfig(f, x),
check::Val=Val(true)) where {F}
return __internal_gradient_capture(
@closure((x, ps)->f.outer(f.inner(x), ps)), cfg, check, x, f.outer.ps)
end

@inline function ForwardDiff.gradient(
f::Base.ComposedFunction{F, <:Lux.StatefulLuxLayer}, x::AbstractArray,
cfg::ForwardDiff.GradientConfig=ForwardDiff.GradientConfig(f, x),
check::Val=Val(true)) where {F}
return __internal_gradient_capture(f, cfg, check, x, f.inner.ps)
end

@inline function ForwardDiff.gradient(f::Lux.StatefulLuxLayer, x::AbstractArray,
cfg::ForwardDiff.GradientConfig=ForwardDiff.GradientConfig(f, x),
check::Val=Val(true))
return __internal_gradient_capture(f, cfg, check, x, f.ps)
end

function CRC.rrule(
cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__internal_gradient_capture),
f::F, jc_cfg::ForwardDiff.GradientConfig, chk::Val, x::AbstractArray, ps) where {F}
if DISABLE_AUTOMATIC_NESTED_AD_SWITCH
y, pb_f = CRC.rrule_via_ad(cfg, ForwardDiff.gradient, Base.Fix2(f, ps), x)
∇internal_jacobian_capture_noswitch = Δ -> begin
@warn "Nested AD switch is disabled for `ForwardDiff.gradient`. If used with \
an outer `Zygote.gradient` call, the gradients wrt parameters `ps` will \
be dropped. Enable nested AD switching to get the correct (full) \
gradients." maxlog=1
_, _, ∂x = pb_f(Δ)
return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(),
CRC.NoTangent(), ∂x, CRC.NoTangent())
end
return y, ∇internal_jacobian_capture_noswitch
end

g = __internal_gradient_capture(f, jc_cfg, chk, x, ps)
∇internal_gradient_capture = Δ_ -> begin
(Δ_ isa CRC.NoTangent || Δ_ isa CRC.ZeroTangent) &&
return ntuple(Returns(CRC.NoTangent()), 6)

Δ = reshape(CRC.unthunk(Δ_), size(x))
∂x, ∂ps = Lux.__forwarddiff_jvp(x, Δ, ps) do x, ps
y, pb_f = CRC.rrule_via_ad(cfg, f, x, ps)
return pb_f(one(y))[2:3]
end
return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), ∂x, ∂ps)
end
return g, ∇internal_gradient_capture
end

@inline function __internal_jacobian_capture(
f::F, cfg::ForwardDiff.JacobianConfig, chk::Val, x, args...) where {F}
# Here we can't really pass in the actual config because we modify the internal function
__f = @closure(x->f(x, args...))
cfg_new = __updated_jacobian_config(cfg, __f, x)
return ForwardDiff.jacobian(__f, x, cfg_new, chk)
return ForwardDiff.jacobian(__f, x, __updated_forwarddiff_config(cfg, __f, x), chk)
end

@inline function ForwardDiff.jacobian(
f::Base.ComposedFunction{<:Lux.StatefulLuxLayer, F}, x::AbstractArray,
cfg::ForwardDiff.JacobianConfig=ForwardDiff.JacobianConfig(f, x),
check::Val=Val(true)) where {F}
return Lux.__internal_forwarddiff_jacobian_capture(
return __internal_jacobian_capture(
@closure((x, ps)->f.outer(f.inner(x), ps)), cfg, check, x, f.outer.ps)
end

@inline function ForwardDiff.jacobian(
f::Base.ComposedFunction{F, <:Lux.StatefulLuxLayer}, x::AbstractArray,
cfg::ForwardDiff.JacobianConfig=ForwardDiff.JacobianConfig(f, x),
check::Val=Val(true)) where {F}
return Lux.__internal_forwarddiff_jacobian_capture(f, cfg, check, x, f.inner.ps)
return __internal_jacobian_capture(f, cfg, check, x, f.inner.ps)
end

@inline function ForwardDiff.jacobian(f::Lux.StatefulLuxLayer, x::AbstractArray,
cfg::ForwardDiff.JacobianConfig=ForwardDiff.JacobianConfig(f, x),
check::Val=Val(true))
return Lux.__internal_forwarddiff_jacobian_capture(f, cfg, check, x, f.ps)
return __internal_jacobian_capture(f, cfg, check, x, f.ps)
end

function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
::typeof(Lux.__internal_forwarddiff_jacobian_capture), f::F,
jc_cfg::ForwardDiff.JacobianConfig, chk::Val, x::AbstractArray, ps) where {F}
function CRC.rrule(
cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__internal_jacobian_capture),
f::F, jc_cfg::ForwardDiff.JacobianConfig, chk::Val, x::AbstractArray, ps) where {F}
if DISABLE_AUTOMATIC_NESTED_AD_SWITCH
y, pb_f = CRC.rrule_via_ad(cfg, ForwardDiff.jacobian, Base.Fix2(f, ps), x)
∇internal_jacobian_capture_noswitch = Δ -> begin
@warn "Nested AD switch is disabled for `ForwardDiff.jacobian`. If used with an \
outer `Zygote.gradient` call, the gradients wrt parameters `ps` will be \
dropped. Enable nested AD switching to get the correct (full) \
@warn "Nested AD switch is disabled for `ForwardDiff.jacobian`. If used with \
an outer `Zygote.gradient` call, the gradients wrt parameters `ps` will \
be dropped. Enable nested AD switching to get the correct (full) \
gradients." maxlog=1
_, _, ∂x = pb_f(Δ)
return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(),
@@ -87,7 +153,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
return y, ∇internal_jacobian_capture_noswitch
end

J = Lux.__internal_forwarddiff_jacobian_capture(f, jc_cfg, chk, x, ps)
J = __internal_jacobian_capture(f, jc_cfg, chk, x, ps)

∇internal_jacobian_capture = Δ_ -> begin
(Δ_ isa CRC.NoTangent || Δ_ isa CRC.ZeroTangent) &&
@@ -100,8 +166,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
y, pb_f = CRC.rrule_via_ad(cfg, __f, x, ps)
return pb_f(one(y))[2:3]
end
∂xᵢ, ∂psᵢ = Lux.__forwarddiff_jvp(__gradient_fn, x, reshape(Δᵢ, size(x)), ps)
return ∂xᵢ, ∂psᵢ
return Lux.__forwarddiff_jvp(__gradient_fn, x, reshape(Δᵢ, size(x)), ps)
end
return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), ∂x, ∂ps)
end