diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index f9dafcdf0..c2c11f12a 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -129,6 +129,7 @@ reshape_norm_dims(y, x) = reshape(x, get_norm_reshape_dims(size(y), length(x))) end CRC.@non_differentiable get_norm_reshape_dims(::Any...) +EnzymeRules.inactive(::typeof(get_norm_reshape_dims), ::Any...) = true # Entry Points ## InstanceNorm diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 3e0c6db6f..b28719219 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -36,8 +36,7 @@ is_training(::Val{training}) where {training} = training sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...))) -function run_batchnorm_testing( - gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu) +function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, aType) epsilon = eps(T)^(5 // 7) x, scale, bias, rm, rv = setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) @@ -81,12 +80,10 @@ function run_batchnorm_testing( @test size(nt.running_var) == (size(x, length(sz) - 1),) end - if is_training(training) && affine - skip_backends = [] - act === relu && push!(skip_backends, AutoFiniteDiff()) - + if is_training(training) @test_gradients(sumabs2first, batchnorm, x, scale, bias, Constant(rm), - Constant(rv), training, act, T(0.9), epsilon; atol, rtol, skip_backends) + Constant(rv), training, act, T(0.9), epsilon; atol, rtol, + enzyme_set_runtime_activity=true) end if anonact !== act @@ -100,7 +97,7 @@ end const ALL_TEST_CONFIGS = Iterators.product( [Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), (Val(true), Val(false)), (true, false), (true, false), - (identity, relu, tanh_fast, sigmoid_fast, anonact)) + (identity, sigmoid_fast, anonact)) const TEST_BLOCKS = collect(Iterators.partition( ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) @@ -115,7 +112,7 @@ end @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) + affine, track_stats, act, aType) end end end @@ -126,7 +123,7 @@ end @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) + affine, track_stats, act, aType) end end end @@ -137,7 +134,7 @@ end @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) + affine, track_stats, act, aType) end end end @@ -148,7 +145,7 @@ end @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) + affine, track_stats, act, aType) end end end @@ -159,7 +156,7 @@ end @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) + affine, track_stats, act, aType) end end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index cd1f9ca6b..aee725e22 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -82,7 +82,7 @@ const ALL_TEST_CONFIGS = Iterators.product([Float32, Float64], ), (2, 3), (true, false), - (identity, relu, tanh_fast, sigmoid_fast, anonact)) + (identity, sigmoid_fast, anonact)) const TEST_BLOCKS = collect(Iterators.partition( ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 5fa25dd79..ab57da3b0 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -68,7 +68,7 @@ end const ALL_TEST_CONFIGS = Iterators.product( [Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), - (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact)) + (Val(true), Val(false)), (identity, sigmoid_fast, anonact)) const TEST_BLOCKS = collect(Iterators.partition( ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 9398d82cd..f39e8a994 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -72,7 +72,7 @@ const ALL_TEST_CONFIGS = Any[] for T in (Float32, Float64), x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), - act in (identity, relu, tanh_fast, sigmoid_fast, anonact) + act in (identity, sigmoid_fast, anonact) push!(ALL_TEST_CONFIGS, (T, x_shape, affine_shape, act)) end