diff --git a/test/common_ops/dropout_tests.jl b/test/common_ops/dropout_tests.jl index 6cf90d5f..5d3baa28 100644 --- a/test/common_ops/dropout_tests.jl +++ b/test/common_ops/dropout_tests.jl @@ -75,8 +75,7 @@ end x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : [])) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), :))) @@ -105,11 +104,8 @@ end soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : [] skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : [] - broken_backends = T == Float16 && Sys.iswindows() && length(x_shape) != 5 ? - [AutoEnzyme()] : [] - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends, - broken_backends) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)))