Skip to content

Commit

Permalink
fix: more enzyme support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 18, 2024
1 parent fe9ac31 commit fc24816
Show file tree
Hide file tree
Showing 14 changed files with 80 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ ComponentArrays = "0.15.18"
ConcreteStructs = "0.2.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.15"
EnzymeCore = "0.8.5"
EnzymeCore = "0.8.6"
FastClosures = "0.3.2"
Flux = "0.14.25"
ForwardDiff = "0.10.36"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ ArrayInterface = "7.9"
ChainRulesCore = "1.24"
Compat = "4.16"
DispatchDoctor = "0.4.10"
EnzymeCore = "0.8.5"
EnzymeCore = "0.8.6"
Functors = "0.5"
MLDataDevices = "1.6"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxCore/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Aqua = "0.8.7"
EnzymeCore = "0.8.5"
EnzymeCore = "0.8.6"
ExplicitImports = "1.9.0"
Functors = "0.5"
MLDataDevices = "1.6"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Compat = "4.16"
CpuId = "0.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.15"
EnzymeCore = "0.8.5"
EnzymeCore = "0.8.6"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
Expand Down
40 changes: 31 additions & 9 deletions lib/LuxLib/src/impl/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,29 @@ function batchnorm_affine_normalize_internal!(
end

function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ)
if γ === nothing && β === nothing
@simd ivdep for J in eachindex(γ′, β′, μ, σ²)
@fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ))
@fastmath @inbounds β′[J] = -μ[J] * γ′[J]
if Utils.within_enzyme_autodiff()
if γ === nothing && β === nothing
for J in eachindex(γ′, β′, μ, σ²)
@inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ))
@inbounds β′[J] = -μ[J] * γ′[J]
end
else
for J in eachindex(γ′, β′, γ, β, μ, σ²)
@inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ)
@inbounds β′[J] = β[J] - μ[J] * γ′[J]
end
end
else
@simd ivdep for J in eachindex(γ′, β′, γ, β, μ, σ²)
@fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ)
@fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J]
if γ === nothing && β === nothing
@simd ivdep for J in eachindex(γ′, β′, μ, σ²)
@fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ))
@fastmath @inbounds β′[J] = -μ[J] * γ′[J]
end
else
@simd ivdep for J in eachindex(γ′, β′, γ, β, μ, σ²)
@fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ)
@fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J]
end
end
end
end
Expand All @@ -115,7 +129,11 @@ function apply_batchnorm_scale_bias_act_cpu!(
if size(y, 1) == 1
apply_batchnorm_scale_bias_act_2d_serial_cpu!(y, γ′, β′, x, σ)
else
apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
if Utils.within_enzyme_autodiff()
apply_batchnorm_scale_bias_act_3d_serial_cpu!(y, γ′, β′, x, σ)
else
apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
end
end
end

Expand Down Expand Up @@ -160,7 +178,11 @@ function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{yT, 3}, γ′::Abstrac
if size(y, 1) == 1
apply_batchnorm_scale_bias_2d_serial_cpu!(y, γ′, β′, x)
else
apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
if Utils.within_enzyme_autodiff()
apply_batchnorm_scale_bias_3d_serial_cpu!(y, γ′, β′, x)
else
apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
end
end
end

Expand Down
5 changes: 2 additions & 3 deletions lib/LuxLib/src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ using Hwloc: Hwloc
using Static: static, False, True

using ..LuxLib: DISABLE_LOOP_VECTORIZATION
using ..Utils: is_extension_loaded, safe_minimum, unsafe_known
using ..Utils: is_extension_loaded, safe_minimum, unsafe_known, within_enzyme_autodiff

const CRC = ChainRulesCore

Expand Down Expand Up @@ -136,8 +136,7 @@ CRC.@non_differentiable explicit_blas_loaded()
use_octavian() = False()
else
function use_octavian()
unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() &&
return False()
within_enzyme_autodiff() && return False()
return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
(INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
end
Expand Down
8 changes: 6 additions & 2 deletions lib/LuxLib/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ within_autodiff(::AbstractArray{<:ForwardDiff.Dual}) = True()

CRC.rrule(::typeof(within_autodiff), x) = True(), _ -> (∂∅, ∂∅)

function within_enzyme_autodiff()
unsafe_known(is_extension_loaded(Val(:Enzyme))) && return EnzymeCore.within_autodiff()
return false
end

static_training_mode(::Nothing, args...) = within_autodiff_vararg(args...)

function static_training_mode(
Expand Down Expand Up @@ -330,8 +335,7 @@ CRC.@non_differentiable static_training_mode_check(::Any...)
else
@inline function can_loopvec_args(args...)
# Avoid loop vectorization inside Enzyme autodiff calls
unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() &&
return false
within_enzyme_autodiff() && return false
return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
end
end
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ BenchmarkTools = "1.5"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.18"
Enzyme = "0.13.15"
EnzymeCore = "0.8.5"
EnzymeCore = "0.8.6"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
Expand Down
6 changes: 2 additions & 4 deletions test/enzyme_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ const MODELS_LIST = [
(Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)),
# XXX: https://github.com/EnzymeAD/Enzyme.jl/issues/2105
# (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)),
# XXX: https://github.com/LuxDL/Lux.jl/issues/1024
# (Bilinear((2, 2) => 3), randn(Float32, 2, 3)),
(Bilinear((2, 2) => 3), randn(Float32, 2, 3)),
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)),
(ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)),
(StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)),
Expand All @@ -66,8 +65,7 @@ const MODELS_LIST = [
(Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)),
(Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
(Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
# XXX: Recent Enzyme release breaks this
# (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)),
(Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)),
(Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
(Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
]
Expand Down
18 changes: 13 additions & 5 deletions test/helpers/loss_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ end

@test @inferred(Zygote.gradient(celoss, ŷ, y)) isa Any

# Failure only on CI
__f = Base.Fix2(celoss, y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(__f, ŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION v"1.11-" ? [AutoEnzyme()] : [])
end

@testset "Logit CrossEntropyLoss" begin
Expand All @@ -175,8 +177,10 @@ end

@test @inferred(Zygote.gradient(logitceloss, logŷ, y)) isa Any

# Failure only on CI
__f = Base.Fix2(logitceloss, y)
@test_gradients(__f, logŷ; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(__f, logŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION v"1.11-" ? [AutoEnzyme()] : [])
end

logŷ, y = randn(3) |> aType, rand(3) |> aType
Expand Down Expand Up @@ -279,8 +283,10 @@ end
else
ongpu ? [AutoTracker()] : []
end
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3,
skip_backends=[AutoFiniteDiff()], broken_backends)
skip_backends = VERSION v"1.11-" ? Any[AutoEnzyme()] : []
push!(skip_backends, AutoFiniteDiff())
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3, skip_backends,
broken_backends)
end
end
end
Expand All @@ -301,8 +307,10 @@ end
@jet KLDivergenceLoss()(ŷ, y)
@test @inferred(Zygote.gradient(KLDivergenceLoss(), ŷ, y)) isa Any

# Failure only on CI
__f = Base.Fix2(KLDivergenceLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(__f, ŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION v"1.11-" ? [AutoEnzyme()] : [])
end

@testset "HingeLoss" begin
Expand Down
18 changes: 6 additions & 12 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends=[AutoEnzyme()])
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

d = Dense(2 => 2)
display(d)
Expand All @@ -269,8 +268,7 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends=[AutoEnzyme()])
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

d = Dense(2 => 3)
display(d)
Expand All @@ -283,8 +281,7 @@ end

@test size(layer(x, ps, st)[1]) == (5, 7, 11)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends=[AutoEnzyme()])
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
end

@testset "Two-streams zero sum" begin
Expand All @@ -299,8 +296,7 @@ end

@test LuxCore.outputsize(layer, (x, y), rng) == (3,)
@jet layer((x, y), ps, st)
@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3,
rtol=1.0f-3, skip_backends=[AutoEnzyme()])
@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3)
end

@testset "Inner interactions" begin
Expand All @@ -311,8 +307,7 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends=[AutoEnzyme()])
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

x = randn(Float32, 2, 1) |> aType
layer = Bilinear(2 => 3)
Expand All @@ -321,8 +316,7 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends=[AutoEnzyme()])
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
end
end
end
Expand Down
12 changes: 5 additions & 7 deletions test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,8 @@ end
scales = (nothing, 2, (2, 1))

@testset for umode in modes, xsize in sizes, scale in scales
if !xor(isnothing(xsize), isnothing(scale))
continue
end
xor(isnothing(xsize), isnothing(scale)) || continue

layer = Upsample(umode; size=xsize, scale=scale)
display(layer)
ps, st = Lux.setup(rng, layer) |> dev
Expand All @@ -242,9 +241,8 @@ end
scales = (nothing, 2, (2, 1, 1), (2, 2, 1))

@testset for umode in modes, xsize in sizes, scale in scales
if !xor(isnothing(xsize), isnothing(scale))
continue
end
xor(isnothing(xsize), isnothing(scale)) || continue

layer = Upsample(umode; size=xsize, scale=scale)
display(layer)
ps, st = Lux.setup(rng, layer) |> dev
Expand All @@ -263,7 +261,7 @@ end

broken_backends = Any[AutoTracker()]
umode == :nearest || push!(broken_backends, AutoReverseDiff())
if VERSION < v"1.11-"
if VERSION < v"1.11-" && umode == :nearest
push!(broken_backends, AutoEnzyme())
end
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
Expand Down
8 changes: 5 additions & 3 deletions test/layers/normalize_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
x_ = m(x, ps, st_)[1] |> CPUDevice()
@test x_[1](1 .- 0.3) / sqrt(1.3) atol=1.0e-5

broken_backends = VERSION v"1.11-" ? [AutoEnzyme()] : []

@jet m(x, ps, st)
@test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
rtol=1.0f-3, skip_backends=[AutoFiniteDiff()], broken_backends)

@testset for affine in (true, false)
m = BatchNorm(2; affine, track_stats=false)
Expand All @@ -54,7 +56,7 @@

@jet m(x, ps, Lux.testmode(st))
@test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
rtol=1.0f-3, skip_backends=[AutoFiniteDiff()], broken_backends)

# with activation function
m = BatchNorm(2, sigmoid; affine)
Expand All @@ -68,7 +70,7 @@
sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon))
@jet m(x, ps, Lux.testmode(st))
@test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
rtol=1.0f-3, skip_backends=[AutoFiniteDiff()], broken_backends)

m = BatchNorm(32; affine)
x = randn(Float32, 416, 416, 32, 1) |> aType
Expand Down
6 changes: 5 additions & 1 deletion test/layers/recurrent_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ end
@test !hasproperty(ps, :hidden_state)
end

@test_gradients(loss_loop, rnncell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
# Failure only on CI
skip_backends = VERSION v"1.11-" && use_bias && act === identity ?
[AutoEnzyme()] : []
@test_gradients(loss_loop, rnncell, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)
end
end

Expand Down

0 comments on commit fc24816

Please sign in to comment.