Skip to content

Commit

Permalink
fix: bypass enzyme bmm failure
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 21, 2024
1 parent 5e979ff commit 3ddd9c5
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
9 changes: 7 additions & 2 deletions lib/LuxLib/src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
batched_matmul_loopvec_impl!(z, x, y)
return
end
NNlib.batched_mul!(z, x, y)
if Utils.within_enzyme_autodiff()
# XXX: https://github.com/LuxDL/Lux.jl/issues/1024
fallback_batched_matmul!(z, LoopedArrayOp(), x, y)
else
NNlib.batched_mul!(z, x, y)
end
return
end

Expand All @@ -78,7 +83,7 @@ end
function fallback_batched_matmul!(
z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
@warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
@warn "Using fallback Batched Matrix Multiply routine for $(opmode) with A: size = \
$(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
slow." maxlog=1

Expand Down
3 changes: 2 additions & 1 deletion lib/LuxLib/test/normalization/batchnorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ end
@test nt.running_var isa aType && length(nt.running_var) == 6

@test_gradients(sumabs2first, batchnorm, x, scale, bias, Constant(running_mean),
Constant(running_var), training, act, 0.9, 1e-5; atol=1.0f-3, rtol=1.0f-3)
Constant(running_var), Val(true), gelu, 0.9, 1e-5; atol=1.0f-3, rtol=1.0f-3,
broken_backends=[AutoEnzyme()])
end
end
3 changes: 1 addition & 2 deletions test/enzyme_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ const MODELS_LIST = Any[
(Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)),
(Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)),
(Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)),
# XXX: https://github.com/LuxDL/Lux.jl/issues/1024
# (Bilinear((2, 2) => 3), randn(Float32, 2, 3)),
(Bilinear((2, 2) => 3), randn(Float32, 2, 3)),
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)),
(ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)),
(StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)),
Expand Down
3 changes: 1 addition & 2 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,7 @@ end
@testitem "Bilinear" setup=[SharedTestSetup] tags=[:core_layers] begin
rng = StableRNG(12345)

# XXX: https://github.com/LuxDL/Lux.jl/issues/1024
skip_backends = [AutoEnzyme()]
skip_backends = VERSION < v"1.11-" ? [AutoEnzyme()] : []

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@testset "SkipConnection recombinator" begin
Expand Down

0 comments on commit 3ddd9c5

Please sign in to comment.