From 3ddd9c5b5968716530378caa6191cfc3f5b81c15 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Nov 2024 08:58:54 -0500 Subject: [PATCH] fix: bypass enzyme bmm failure --- lib/LuxLib/src/impl/batched_mul.jl | 9 +++++++-- lib/LuxLib/test/normalization/batchnorm_tests.jl | 3 ++- test/enzyme_tests.jl | 3 +-- test/layers/basic_tests.jl | 3 +-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 8d91951293..37e62de674 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -61,7 +61,12 @@ function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, batched_matmul_loopvec_impl!(z, x, y) return end - NNlib.batched_mul!(z, x, y) + if Utils.within_enzyme_autodiff() + # XXX: https://github.com/LuxDL/Lux.jl/issues/1024 + fallback_batched_matmul!(z, LoopedArrayOp(), x, y) + else + NNlib.batched_mul!(z, x, y) + end return end @@ -78,7 +83,7 @@ end function fallback_batched_matmul!( z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ + @warn "Using fallback Batched Matrix Multiply routine for $(opmode) with A: size = \ $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ slow." maxlog=1 diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 1f5fb342f5..2ad299d796 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -145,6 +145,7 @@ end @test nt.running_var isa aType && length(nt.running_var) == 6 @test_gradients(sumabs2first, batchnorm, x, scale, bias, Constant(running_mean), - Constant(running_var), training, act, 0.9, 1e-5; atol=1.0f-3, rtol=1.0f-3) + Constant(running_var), Val(true), gelu, 0.9, 1e-5; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme()]) end end diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index c97c9e3e88..8fbf085fc5 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -44,8 +44,7 @@ const MODELS_LIST = Any[ (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), - # XXX: https://github.com/LuxDL/Lux.jl/issues/1024 - # (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), + (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 397b668d38..3adea7323d 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -242,8 +242,7 @@ end @testitem "Bilinear" setup=[SharedTestSetup] tags=[:core_layers] begin rng = StableRNG(12345) - # XXX: https://github.com/LuxDL/Lux.jl/issues/1024 - skip_backends = [AutoEnzyme()] + skip_backends = VERSION < v"1.11-" ? [AutoEnzyme()] : [] @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "SkipConnection recombinator" begin