From 000b6914557a6d6c426a0f90ec90b517f5c5f567 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Sep 2024 20:35:12 -0400 Subject: [PATCH] feat: add forward mode batched enzyme jacobian --- ext/LuxEnzymeExt/LuxEnzymeExt.jl | 20 ++++++++--- ext/LuxEnzymeExt/batched_autodiff.jl | 50 ++++++++++++++++++++++++++++ src/Lux.jl | 1 + src/autodiff/api.jl | 38 ++++++++++++--------- 4 files changed, 90 insertions(+), 19 deletions(-) create mode 100644 ext/LuxEnzymeExt/batched_autodiff.jl diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 1c6b5cc887..512c31740e 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -1,16 +1,28 @@ module LuxEnzymeExt -using ADTypes: AutoEnzyme -using Enzyme: Enzyme, Active, Const, Duplicated +using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode +using ArgCheck: @argcheck +using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated using EnzymeCore: EnzymeCore using Functors: fmap -using Setfield: @set! -using Static: False, True +using Setfield: @set!, @set +using Static: False, True, StaticBool using Lux: Lux, Utils using Lux.Training: TrainingBackendCache, TrainState using MLDataDevices: isleaf +Lux.is_extension_loaded(::Val{:Enzyme}) = true + +normalize_backend(::StaticBool, ad::AutoEnzyme) = ad +normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Forward) +normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Reverse) + +annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f +annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f) + include("training.jl") +include("batched_autodiff.jl") + end diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl new file mode 100644 index 0000000000..5d1d1a0c06 --- /dev/null +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -0,0 +1,50 @@ +function Lux.AutoDiffInternalImpl.batched_jacobian_impl( + f::F, ad::AutoEnzyme, x::AbstractArray) where {F} + backend = normalize_backend(True(), ad) + return batched_enzyme_jacobian_impl( + annotate_function(ad, f), backend, ADTypes.mode(backend), x) +end + +function batched_enzyme_jacobian_impl( + f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F} + # We need to run the function once to get the output type. Can we use ForwardWithPrimal? + y = f(x) + + @argcheck y isa AbstractArray MethodError + if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x)) + throw(AssertionError("`batched_jacobian` only supports batched outputs \ + (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x)).")) + end + B = size(y, ndims(y)) + + J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]), + prod(size(x)[1:(end - 1)]), B) + + chunk_size = min(8, length(y) ÷ B) + partials = ntuple(_ -> zero(x), chunk_size) + + for i in 1:chunk_size:(length(x) ÷ B) + idxs = i:min(i + chunk_size - 1, length(x) ÷ B) + partials′ = make_onehot!(partials, idxs) + J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′))) + for (idx, J_partial) in zip(idxs, J_partials) + copyto!(view(J, :, idx, :), reshape(J_partial, :, B)) + end + end + + return J +end + +function batched_enzyme_jacobian_impl( + f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F} + error("reverse mode is not supported yet") +end + +function make_onehot!(partials, idxs) + for (idx, partial) in zip(idxs, partials) + partial′ = reshape(partial, :, size(partial, ndims(partial))) + fill!(partial′, false) + fill!(view(partial′, idx, :), true) + end + return partials[1:length(idxs)] +end diff --git a/src/Lux.jl b/src/Lux.jl index 1b99492bb5..033396c3b8 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -34,6 +34,7 @@ const NAME_TYPE = Union{Nothing, String, Symbol} const Optional{T} = Union{T, Nothing} is_extension_loaded(::Val) = false +is_extension_loaded(::Val{:ForwardDiff}) = true # Preferences include("preferences.jl") diff --git a/src/autodiff/api.jl b/src/autodiff/api.jl index 6db1e38da6..1077b6710a 100644 --- a/src/autodiff/api.jl +++ b/src/autodiff/api.jl @@ -33,10 +33,7 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F} end function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F} - if !is_extension_loaded(Val(:Zygote)) - error("`Zygote.jl` must be loaded for `vector_jacobian_product` \ - to work with `$(backend)`.") - end + assert_backend_loaded(:vector_jacobian_product, backend) return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u) end @@ -89,10 +86,11 @@ the following properties for `y = f(x)`: ## Backends & AD Packages -| Supported Backends | Packages Needed | -|:------------------ |:--------------- | -| `AutoForwardDiff` | | -| `AutoZygote` | `Zygote.jl` | +| Supported Backends | Packages Needed | Note | +|:------------------ |:--------------- |:---------------------------------------------- | +| `AutoForwardDiff` | | | +| `AutoZygote` | `Zygote.jl` | | +| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD | ## Arguments @@ -118,14 +116,24 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`.")) end -function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F} - return AutoDiffInternalImpl.batched_jacobian(f, backend, x) +for implemented_backend in (AutoForwardDiff, AutoZygote, AutoEnzyme) + @eval function batched_jacobian( + f::F, backend::$(implemented_backend), x::AbstractArray) where {F} + assert_backend_loaded(:batched_jacobian, backend) + return AutoDiffInternalImpl.batched_jacobian(f, backend, x) + end end -function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F} - if !is_extension_loaded(Val(:Zygote)) - error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \ - `$(backend)`.") +function assert_backend_loaded(fname::Symbol, ad::AbstractADType) + return assert_backend_loaded(fname, ad, adtype_to_backend(ad)) +end +function assert_backend_loaded(fname::Symbol, ad::AbstractADType, backend::Val{B}) where {B} + if !is_extension_loaded(backend) + error("$(fname) with `$(ad)` requires $(B).jl to be loaded.") end - return AutoDiffInternalImpl.batched_jacobian(f, backend, x) + return end + +adtype_to_backend(::AutoEnzyme) = Val(:Enzyme) +adtype_to_backend(::AutoForwardDiff) = Val(:ForwardDiff) +adtype_to_backend(::AutoZygote) = Val(:Zygote)