From 40579d1863a9c45d2329c43ec6c8ce9b6c3544dd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 17 Nov 2024 21:58:10 -0500 Subject: [PATCH] fix: avoid LV or Octavian with Enzyme --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/traits.jl | 5 ++- lib/LuxLib/src/utils.jl | 3 ++ lib/LuxTestUtils/src/utils.jl | 2 +- test/helpers/size_propagator_test.jl | 40 ----------------------- test/layers/basic_tests.jl | 7 +++-- test/layers/normalize_tests.jl | 47 ++++++++++++++-------------- 7 files changed, 36 insertions(+), 70 deletions(-) delete mode 100644 test/helpers/size_propagator_test.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f9365f311f..5cb751d7db 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.8" +version = "1.3.9" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 29d3dc1e0c..a9164f2c4d 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -77,11 +77,12 @@ end module System using ChainRulesCore: ChainRulesCore +using EnzymeCore: EnzymeCore using Hwloc: Hwloc using Static: static, False, True using ..LuxLib: DISABLE_LOOP_VECTORIZATION -using ..Utils: is_extension_loaded, safe_minimum +using ..Utils: is_extension_loaded, safe_minimum, unsafe_known const CRC = ChainRulesCore @@ -135,6 +136,8 @@ CRC.@non_differentiable explicit_blas_loaded() use_octavian() = False() else function use_octavian() + unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() && + return False() return is_extension_loaded(Val(:Octavian)) & is_x86_64() & (INTEL_HARDWARE | AMD_RYZEN_HARDWARE) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0104457c79..14748d67f8 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -329,6 +329,9 @@ CRC.@non_differentiable static_training_mode_check(::Any...) @inline can_loopvec_args(args...) = false else @inline function can_loopvec_args(args...) + # Avoid loop vectorization inside Enzyme autodiff calls + unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() && + return false return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...) end end diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl index 5233e7bddd..1da442eb3f 100644 --- a/lib/LuxTestUtils/src/utils.jl +++ b/lib/LuxTestUtils/src/utils.jl @@ -65,7 +65,7 @@ function aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N} y = length(x) > 1 ? reduce(vcat, x) : reduce(vcat, [x[1], x[1]])[1:1] return reshape(y, size(x)) end -aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x) +aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal, N}) where {N} = Tracker.collect(x) function needs_gradient(y) leaves = Functors.fleaves(y) diff --git a/test/helpers/size_propagator_test.jl b/test/helpers/size_propagator_test.jl deleted file mode 100644 index 7c41e150f3..0000000000 --- a/test/helpers/size_propagator_test.jl +++ /dev/null @@ -1,40 +0,0 @@ -@testitem "Size Propagator" setup=[SharedTestSetup] tags=[:misc] begin - rng = StableRNG(12345) - - @testset "Simple Chain (LeNet)" begin - lenet = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), - Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3), - Dense(256 => 120, relu), Dense(120 => 84, relu), Dense(84 => 10)) - - for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) - @test Lux.outputsize(lenet, x, rng) == (10,) - end - end - - @testset "Chain with BatchNorm" begin - lenet = Chain(Conv((5, 5), 1 => 6, relu), BatchNorm(6, relu), MaxPool((2, 2)), - Conv((5, 5), 6 => 16, relu), BatchNorm(16, relu), - MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), - BatchNorm(120, relu), Dense(120 => 84, relu), Dropout(0.5f0), - BatchNorm(84, relu), Dense(84 => 10), BatchNorm(10, relu)) - - for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) - @test Lux.outputsize(lenet, x, rng) == (10,) - end - end - - norm_layer = [ - (BatchNorm(3, relu), [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 3, 3)]), - (GroupNorm(6, 3, relu), - [randn(rng, Float32, 4, 4, 6, 2), randn(rng, Float32, 6, 3)]), - (InstanceNorm(3, relu), - [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 4, 3, 2)]), - (LayerNorm((2, 1, 3), relu), - [randn(rng, Float32, 2, 4, 3, 2), randn(rng, Float32, 2, 1, 3, 3)])] - - @testset "Normalization: $(nameof(typeof(layer)))" for (layer, xs) in norm_layer - for x in xs - @test Lux.outputsize(layer, x, rng) == size(x)[1:(end - 1)] - end - end -end diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 7da8995025..05f1ec6592 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -63,7 +63,8 @@ @jet layer(x, ps, st) x = randn(rng, 6, 3) |> aType - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoReverseDiff()]) end @testset "SelectDim Layer" begin @@ -299,8 +300,8 @@ end @test LuxCore.outputsize(layer, (x, y), rng) == (3,) @jet layer((x, y), ps, st) - @test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3, - skip_backends=[AutoEnzyme()]) + @test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, + rtol=1.0f-3, skip_backends=[AutoEnzyme()]) end @testset "Inner interactions" begin diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index f3c35ed43f..21da1f7462 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -43,19 +43,18 @@ @test x_[1]≈(1 .- 0.3) / sqrt(1.3) atol=1.0e-5 @jet m(x, ps, st) - __f = (x, ps) -> sum(first(m(x, ps, st))) - @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, - skip_backends=[AutoFiniteDiff()]) + @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3, + rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()]) - for affine in (true, false) + @testset for affine in (true, false) m = BatchNorm(2; affine, track_stats=false) x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] |> aType display(m) ps, st = Lux.setup(rng, m) |> dev @jet m(x, ps, Lux.testmode(st)) - @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3, rtol=1.0f-3, - skip_backends=[AutoFiniteDiff()]) + @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3, + rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()]) # with activation function m = BatchNorm(2, sigmoid; affine) @@ -68,16 +67,8 @@ @test y ≈ sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon)) @jet m(x, ps, Lux.testmode(st)) - - if affine - __f = (x, ps) -> sum(first(m(x, ps, st))) - @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, - skip_backends=[AutoFiniteDiff()]) - else - __f = x -> sum(first(m(x, ps, st))) - @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - skip_backends=[AutoFiniteDiff()]) - end + @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3, + rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()]) m = BatchNorm(32; affine) x = randn(Float32, 416, 416, 32, 1) |> aType @@ -170,7 +161,8 @@ end x = randn(rng, Float32, 3, 3, 3, 1) |> aType @jet wn(x, ps, st) - @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) wn = WeightNorm(c, (:weight,)) display(wn) @@ -178,7 +170,8 @@ end x = randn(rng, Float32, 3, 3, 3, 1) |> aType @jet wn(x, ps, st) - @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) wn = WeightNorm(c, (:weight, :bias), (2, 2)) display(wn) @@ -186,7 +179,8 @@ end x = randn(rng, Float32, 3, 3, 3, 1) |> aType @jet wn(x, ps, st) - @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) wn = WeightNorm(c, (:weight,), (2,)) display(wn) @@ -194,7 +188,8 @@ end x = randn(rng, Float32, 3, 3, 3, 1) |> aType @jet wn(x, ps, st) - @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) end @testset "Dense" begin @@ -206,7 +201,8 @@ end x = randn(rng, Float32, 3, 1) |> aType @jet wn(x, ps, st) - @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) wn = WeightNorm(d, (:weight,)) display(wn) @@ -214,7 +210,8 @@ end x = randn(rng, Float32, 3, 1) |> aType @jet wn(x, ps, st) - @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) wn = WeightNorm(d, (:weight, :bias), (2, 2)) display(wn) @@ -222,7 +219,8 @@ end x = randn(rng, Float32, 3, 1) |> aType @jet wn(x, ps, st) - @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) wn = WeightNorm(d, (:weight,), (2,)) display(wn) @@ -230,7 +228,8 @@ end x = randn(rng, Float32, 3, 1) |> aType @jet wn(x, ps, st) - @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) end # See https://github.com/LuxDL/Lux.jl/issues/95