
Commit 8d8365c

fix: more BN test fixes
avik-pal committed Nov 20, 2024
1 parent c7eed1a commit 8d8365c
Showing 5 changed files with 14 additions and 16 deletions.
1 change: 1 addition & 0 deletions lib/LuxLib/src/impl/normalization.jl
@@ -129,6 +129,7 @@ reshape_norm_dims(y, x) = reshape(x, get_norm_reshape_dims(size(y), length(x)))
end

CRC.@non_differentiable get_norm_reshape_dims(::Any...)
+EnzymeRules.inactive(::typeof(get_norm_reshape_dims), ::Any...) = true

# Entry Points
## InstanceNorm
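For context on the one-line addition above: `EnzymeRules.inactive` is Enzyme's counterpart to ChainRulesCore's `@non_differentiable`. `get_norm_reshape_dims` only computes a reshape target from array sizes, so both AD systems can safely treat it as a constant computation. A minimal self-contained sketch of the same opt-out pattern, using a hypothetical helper rather than LuxLib's internals:

```julia
# Hypothetical helper (not from LuxLib): a pure shape computation whose
# output should never carry derivatives.
import ChainRulesCore as CRC
import EnzymeCore: EnzymeRules

my_reshape_dims(sz::Dims, n::Int) = ntuple(i -> i <= n ? sz[i] : 1, length(sz))

# The same two opt-out rules the commit uses, one per AD system.
CRC.@non_differentiable my_reshape_dims(::Any...)
EnzymeRules.inactive(::typeof(my_reshape_dims), ::Any...) = true
```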
23 changes: 10 additions & 13 deletions lib/LuxLib/test/normalization/batchnorm_tests.jl
@@ -36,8 +36,7 @@ is_training(::Val{training}) where {training} = training

sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...)))

-function run_batchnorm_testing(
-gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu)
+function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, aType)
epsilon = eps(T)^(5 // 7)
x, scale, bias, rm, rv = setup_batchnorm(gen_f, aType, T, sz; track_stats, affine)
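Two small helpers from the top of this file do the heavy lifting in these tests: `is_training` lifts the `Val`-wrapped training flag back to a `Bool` at compile time, and `sumabs2first` collapses a layer's first return value to a scalar loss that `@test_gradients` can differentiate. A tiny illustration of the `Val` dispatch (the definition is copied verbatim from this file):

```julia
# The flag is a type parameter, so branches on is_training(...) can
# constant-fold away instead of being checked at run time.
is_training(::Val{training}) where {training} = training

@assert is_training(Val(true)) === true
@assert is_training(Val(false)) === false
```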

@@ -81,12 +80,10 @@ function run_batchnorm_testing(
@test size(nt.running_var) == (size(x, length(sz) - 1),)
end

-if is_training(training) && affine
-skip_backends = []
-act === relu && push!(skip_backends, AutoFiniteDiff())
-
+if is_training(training)
@test_gradients(sumabs2first, batchnorm, x, scale, bias, Constant(rm),
-Constant(rv), training, act, T(0.9), epsilon; atol, rtol, skip_backends)
+Constant(rv), training, act, T(0.9), epsilon; atol, rtol,
+enzyme_set_runtime_activity=true)
end

if anonact !== act
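The change above swaps the FiniteDiff skip-list for `enzyme_set_runtime_activity=true`, which asks the test harness to run Enzyme with runtime activity analysis enabled. That mode tolerates values whose activity (differentiated or constant) is only known at run time, as happens here when `affine`/`track_stats` toggle which arrays participate in the gradient. A rough standalone sketch of that Enzyme mode, independent of `@test_gradients`' internals (the toy `loss` and data are assumptions, not from the commit):

```julia
using Enzyme

loss(x, p) = sum(abs2, x .* p)

x = rand(3); dx = zero(x)   # dx receives the gradient
p = ones(3)                 # held constant: mixed activity

# set_runtime_activity(Reverse) enables runtime activity analysis for
# reverse mode; Active marks the scalar return as differentiated.
autodiff(set_runtime_activity(Reverse), loss, Active,
    Duplicated(x, dx), Const(p))
@show dx
```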
@@ -100,7 +97,7 @@ end
const ALL_TEST_CONFIGS = Iterators.product(
[Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
(Val(true), Val(false)), (true, false), (true, false),
-(identity, relu, tanh_fast, sigmoid_fast, anonact))
+(identity, sigmoid_fast, anonact))

const TEST_BLOCKS = collect(Iterators.partition(
ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5)))
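For orientation: each of these test files builds the full Cartesian product of settings and shards it into five roughly equal blocks, one per CI group. A condensed, self-contained version of that pattern (the particular settings here are placeholders):

```julia
# Shard a Cartesian product of test settings into 5 blocks.
configs = collect(Iterators.product(
    [Float32, Float64],        # eltype
    (Val(true), Val(false)),   # training flag
    (true, false)))            # affine flag

blocks = collect(Iterators.partition(configs, ceil(Int, length(configs) / 5)))
@assert sum(length, blocks) == length(configs)   # nothing dropped
```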
@@ -115,7 +112,7 @@ end
@testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1]
!fp64 && T == Float64 && continue
run_batchnorm_testing(generate_fixed_array, T, sz, training,
-affine, track_stats, act, aType, mode, ongpu)
+affine, track_stats, act, aType)
end
end
end
@@ -126,7 +123,7 @@ end
@testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2]
!fp64 && T == Float64 && continue
run_batchnorm_testing(generate_fixed_array, T, sz, training,
-affine, track_stats, act, aType, mode, ongpu)
+affine, track_stats, act, aType)
end
end
end
@@ -137,7 +134,7 @@ end
@testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3]
!fp64 && T == Float64 && continue
run_batchnorm_testing(generate_fixed_array, T, sz, training,
-affine, track_stats, act, aType, mode, ongpu)
+affine, track_stats, act, aType)
end
end
end
@@ -148,7 +145,7 @@ end
@testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4]
!fp64 && T == Float64 && continue
run_batchnorm_testing(generate_fixed_array, T, sz, training,
-affine, track_stats, act, aType, mode, ongpu)
+affine, track_stats, act, aType)
end
end
end
@@ -159,7 +156,7 @@ end
@testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5]
!fp64 && T == Float64 && continue
run_batchnorm_testing(generate_fixed_array, T, sz, training,
-affine, track_stats, act, aType, mode, ongpu)
+affine, track_stats, act, aType)
end
end
end
2 changes: 1 addition & 1 deletion lib/LuxLib/test/normalization/groupnorm_tests.jl
@@ -82,7 +82,7 @@ const ALL_TEST_CONFIGS = Iterators.product([Float32, Float64],
),
(2, 3),
(true, false),
-(identity, relu, tanh_fast, sigmoid_fast, anonact))
+(identity, sigmoid_fast, anonact))

const TEST_BLOCKS = collect(Iterators.partition(
ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5)))
2 changes: 1 addition & 1 deletion lib/LuxLib/test/normalization/instancenorm_tests.jl
@@ -68,7 +68,7 @@ end

const ALL_TEST_CONFIGS = Iterators.product(
[Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
-(Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact))
+(Val(true), Val(false)), (identity, sigmoid_fast, anonact))

const TEST_BLOCKS = collect(Iterators.partition(
ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5)))
2 changes: 1 addition & 1 deletion lib/LuxLib/test/normalization/layernorm_tests.jl
@@ -72,7 +72,7 @@ const ALL_TEST_CONFIGS = Any[]
for T in (Float32, Float64),
x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)),
affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])),
-act in (identity, relu, tanh_fast, sigmoid_fast, anonact)
+act in (identity, sigmoid_fast, anonact)

push!(ALL_TEST_CONFIGS, (T, x_shape, affine_shape, act))
end
