diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml
index 403bc57fb..1e1b5c58b 100644
--- a/lib/LuxLib/test/Project.toml
+++ b/lib/LuxLib/test/Project.toml
@@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -61,5 +62,6 @@ Static = "0.8.4, 1"
 StaticArrays = "1.9.7"
 Statistics = "1.10"
 Test = "1.10"
+TestExtras = "0.3.1"
 Tracker = "0.2.36"
 Zygote = "0.6.70"
diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl
index 2789e7d4c..8a2a56def 100644
--- a/lib/LuxLib/test/common_ops/activation_tests.jl
+++ b/lib/LuxLib/test/common_ops/activation_tests.jl
@@ -30,18 +30,18 @@
         @test eltype(y2) == T
         @test eltype(y3) == T

-        @test @inferred(apply_act(f, x)) isa Any
-        @test @inferred(apply_act_fast(f, x)) isa Any
-        @test @inferred(apply_act_fast2(f, x)) isa Any
+        @constinferred apply_act(f, x)
+        @constinferred apply_act_fast(f, x)
+        @constinferred apply_act_fast2(f, x)

         @jet apply_act_fast(f, x)
         @jet apply_act_fast2(f, x)

-        @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
+        @constinferred Zygote.gradient(apply_act, f, x)
         if f !== lisht
-            @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
+            @constinferred Zygote.gradient(apply_act_fast, f, x)
         end
-        @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any
+        @constinferred Zygote.gradient(apply_act_fast2, f, x)

         @test_gradients(apply_act, f, x; atol, rtol)
         @test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()])
diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl
index 4e0e51ced..1e932f3d9 100644
--- a/lib/LuxLib/test/common_ops/bias_act_tests.jl
+++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl
@@ -5,12 +5,6 @@
     bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b))
     bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b))

-    struct __Fix1{F, A}
-        f::F
-        act::A
-    end
-    (f::__Fix1)(x, b) = f.f(f.act, x, b)
-
     @testset "$mode" for (mode, aType, ongpu, fp64) in MODES
         @testset "$act, $T, $sz" for act in [
                 identity, relu, sigmoid, sigmoid_fast, softplus,
@@ -27,9 +21,8 @@
             y2 = bias_act_loss2(act, x, b)
             y3 = bias_act_loss3(act, x, b)

-            fp16 = T == Float16
-            atol = fp16 ? 1.0f-2 : 1.0f-3
-            rtol = fp16 ? 1.0f-2 : 1.0f-3
+            atol = 1.0f-3
+            rtol = 1.0f-3

             @test y1≈y2 atol=atol rtol=rtol
             @test y1≈y3 atol=atol rtol=rtol
@@ -37,28 +30,25 @@
             @test eltype(y2) == T
             @test eltype(y3) == T

-            @test @inferred(bias_act_loss1(act, x, b)) isa Any
-            @test @inferred(bias_act_loss2(act, x, b)) isa Any
-            @test @inferred(bias_act_loss3(act, x, b)) isa Any
+            @constinferred bias_act_loss1(act, x, b)
+            @constinferred bias_act_loss2(act, x, b)
+            @constinferred bias_act_loss3(act, x, b)

             @jet bias_act_loss2(act, x, b)
             @jet bias_act_loss3(act, x, b)

-            if act !== lisht && T != Float16
-                @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
-                @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
+            if act !== lisht
+                @constinferred Zygote.gradient(bias_act_loss2, act, x, b)
+                @constinferred Zygote.gradient(bias_act_loss3, act, x, b)
             end

-            @test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,
-                soft_fail=fp16 ? [AutoFiniteDiff()] : [])
-            @test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol,
-                soft_fail=fp16 ? [AutoFiniteDiff()] : [])
-            @test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol,
-                soft_fail=fp16 ? [AutoFiniteDiff()] : [])
+            @test_gradients(bias_act_loss1, act, x, b; atol, rtol)
+            @test_gradients(bias_act_loss2, act, x, b; atol, rtol)
+            @test_gradients(bias_act_loss3, act, x, b; atol, rtol)

-            ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b)
-            ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b)
-            ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b)
+            _, ∂x1, ∂b1 = Zygote.pullback(bias_act_loss1, act, x, b)
+            _, ∂x2, ∂b2 = Zygote.pullback(bias_act_loss2, act, x, b)
+            _, ∂x3, ∂b3 = Zygote.pullback(bias_act_loss3, act, x, b)

             @test ∂x1≈∂x2 atol=atol rtol=rtol
             @test ∂x1≈∂x3 atol=atol rtol=rtol
diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl
index b58aafcd3..ee223c09b 100644
--- a/lib/LuxLib/test/common_ops/conv_tests.jl
+++ b/lib/LuxLib/test/common_ops/conv_tests.jl
@@ -47,16 +47,15 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
     @jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

     if mode != "amdgpu" && activation !== anonact
-        @test @inferred(Zygote.gradient(
-            sumabs2conv, activation, weight, x, bias, cdims
-        )) isa Any
+        @test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any
     else
         try
-            @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims))
-            @test true
+            @test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any
         catch e
             e isa ErrorException || rethrow()
-            @test_broken false
+            @test_broken @inferred(Zygote.gradient(
+                sumabs2conv, activation, weight, x, bias, cdims
+            ))
         end
     end
diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl
index bc4d40e55..9689a5ca8 100644
--- a/lib/LuxLib/test/common_ops/dense_tests.jl
+++ b/lib/LuxLib/test/common_ops/dense_tests.jl
@@ -117,23 +117,23 @@
 end

 @testitem "Fused Dense: StaticArrays" tags=[:dense] begin
-    using StaticArrays, NNlib
+    using StaticArrays, NNlib, TestExtras

     x = @SArray rand(2, 4)
     weight = @SArray rand(3, 2)
     bias = @SArray rand(3)

-    @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray
+    @constinferred fused_dense_bias_activation(relu, weight, x, bias)
 end

 @testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin
-    using JLArrays, NNlib
+    using JLArrays, NNlib, TestExtras

     x = JLArray(rand(Float32, 2, 4))
     weight = JLArray(rand(Float32, 3, 2))
     bias = JLArray(rand(Float32, 3))

-    @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray
+    @constinferred fused_dense_bias_activation(relu, weight, x, bias)
     @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp
 end
diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl
index 1ec9b4618..e1de98c7e 100644
--- a/lib/LuxLib/test/common_ops/dropout_tests.jl
+++ b/lib/LuxLib/test/common_ops/dropout_tests.jl
@@ -10,7 +10,7 @@
             x = randn(rng, T, x_shape) |> aType

-            @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any
+            @constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims)

             y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims)
@@ -21,10 +21,10 @@
             @test rng != rng_

             @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims)))
-            @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any
+            @constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims)

             __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims)))
-            @test @inferred(Zygote.gradient(__f, x)) isa Any
+            @constinferred Zygote.gradient(__f, x)

             @test_gradients(sumabs2first, dropout, rng, x, T(0.5), Val(true), T(2), dims;
                 atol=1.0f-3, rtol=1.0f-3)
@@ -54,8 +54,7 @@ end
             mask = rand(T, x_shape) |> aType

             # Update mask
-            @test @inferred(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any
+            @constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)

             y, mask_, rng_ = dropout(
                 rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)
@@ -69,7 +68,7 @@ end
             __f = (x, mask) -> sum(first(dropout(
                 StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
-            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any
+            @constinferred Zygote.gradient(__f, x, mask)

             @test_gradients(sumabs2first, dropout,
                 rng, x, LuxTestUtils.Constant(mask), T(0.5), Val(true), Val(true),
@@ -79,8 +78,7 @@ end
                 rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))

             # Try using mask if possible (possible!!)
-            @test @inferred(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any
+            @constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)

             y, mask_, rng_ = dropout(
                 rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)
@@ -94,7 +92,7 @@ end
             __f = (x, mask) -> sum(first(dropout(
                 StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
-            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any
+            @constinferred Zygote.gradient(__f, x, mask)

             @test_gradients(sumabs2first, dropout,
                 rng, x, LuxTestUtils.Constant(mask),
@@ -107,8 +105,7 @@ end
             mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

             # Testing Mode
-            @test @inferred(dropout(
-                rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any
+            @constinferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)

             y, mask_, rng_ = dropout(
                 rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)
@@ -135,7 +132,7 @@ end
         x = randn(rng, T, x_shape) |> aType

-        @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any
+        @constinferred alpha_dropout(rng, x, T(0.5), Val(true))

         y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true))
@@ -146,13 +143,13 @@ end
         @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2

         __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true))))
-        @test @inferred(Zygote.gradient(__f, x)) isa Any
+        @constinferred Zygote.gradient(__f, x)

         @test_gradients(sumabs2first, alpha_dropout, rng, x, T(0.5), Val(true);
             atol=1.0f-3, rtol=1.0f-3)

         @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
-        @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any
+        @constinferred alpha_dropout(rng, x, T(0.5), Val(false))

         y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false))
diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl
index 58b6196c1..ea8cb02e2 100644
--- a/lib/LuxLib/test/normalization/batchnorm_tests.jl
+++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl
@@ -1,5 +1,5 @@
 @testsetup module BatchNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, TestExtras

 function setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool)
     x = gen_f(T, sz) |> aType
@@ -89,10 +89,8 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
     end

     if anonact !== act
-        lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm(
-            x, sc, b, rm, rv, tr, act, ϵ)))
         @test @inferred(Zygote.gradient(
-            lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any
+            sumabs2first, x, scale, bias, rm, rv, training, act, epsilon)) isa Any
     end
 end
diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl
index c103595f9..6302bc6dd 100644
--- a/lib/LuxLib/test/normalization/groupnorm_tests.jl
+++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl
@@ -62,8 +62,8 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
     @jet groupnorm(x, scale, bias, groups, act, epsilon)

     if anonact !== act
-        lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ))
-        @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any
+        @test @inferred(Zygote.gradient(
+            sumabs2groupnorm, x, scale, bias, groups, act, epsilon)) isa Any
     end

     @test y isa aType{T, length(sz)}
diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl
index dd999ff09..a0e9e2130 100644
--- a/lib/LuxLib/test/normalization/instancenorm_tests.jl
+++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl
@@ -51,10 +51,9 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
     @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

     if anonact !== act && is_training(training)
-        lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm(
-            x, sc, b, rm, rv, Val(true), act, m, ϵ)))
         @test @inferred(Zygote.gradient(
-            lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any
+            sumabs2instancenorm, x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa
+            Any
     end

     @test y isa aType{T, length(sz)}
diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl
index 6b82390a4..43b989615 100644
--- a/lib/LuxLib/test/normalization/layernorm_tests.jl
+++ b/lib/LuxLib/test/normalization/layernorm_tests.jl
@@ -1,5 +1,5 @@
 @testsetup module LayerNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics, TestExtras
 using LuxTestUtils: check_approx

 function setup_layernorm(gen_f, aType, T, x_size, affine_shape, expand_dims::Bool=true)
@@ -60,8 +60,8 @@ function run_layernorm_testing_core(
         soft_fail=[AutoFiniteDiff()])

     if anonact !== act
-        lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ))
-        @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any
+        @test @inferred(Zygote.gradient(
+            sumabs2layernorm, x, scale, bias, act, dims, epsilon)) isa Any
     end
 end
diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl
index 77cdab470..c2072420f 100644
--- a/lib/LuxLib/test/shared_testsetup.jl
+++ b/lib/LuxLib/test/shared_testsetup.jl
@@ -2,7 +2,7 @@
 import Reexport: @reexport

 using LuxLib, MLDataDevices
-@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib
+@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib, TestExtras

 LuxTestUtils.jet_target_modules!(["LuxLib"])
diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml
index e1e1f1e10..d14f76a58 100644
--- a/lib/MLDataDevices/test/Project.toml
+++ b/lib/MLDataDevices/test/Project.toml
@@ -17,6 +17,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -39,5 +40,6 @@ ReverseDiff = "1.15"
 SafeTestsets = "0.1"
 SparseArrays = "1.10"
 Test = "1.10"
+TestExtras = "0.3.1"
 Tracker = "0.2.36"
 Zygote = "0.6.69"
diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl
index a771ada6e..9099ceb08 100644
--- a/lib/MLDataDevices/test/amdgpu_tests.jl
+++ b/lib/MLDataDevices/test/amdgpu_tests.jl
@@ -1,4 +1,4 @@
-using MLDataDevices, Random, Test
+using MLDataDevices, Random, Test, TestExtras
 using ArrayInterface: parameterless_type

 @testset "CPU Fallback" begin
@@ -122,7 +122,7 @@ using FillArrays, Zygote # Extensions
         ps = (; weight=x, bias=x, d=(x, x))

         return_val(x) = Val(get_device_type(x))  # If it is a compile time constant then type inference will work
-        @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
+        @constinferred Val{parameterless_type(typeof(device))} return_val(ps)
     end
 end
diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl
index 2fce4806a..cc0b7ff23 100644
--- a/lib/MLDataDevices/test/cuda_tests.jl
+++ b/lib/MLDataDevices/test/cuda_tests.jl
@@ -1,4 +1,4 @@
-using MLDataDevices, Random, Functors, Test
+using MLDataDevices, Random, Functors, Test, TestExtras
 using ArrayInterface: parameterless_type

 @testset "CPU Fallback" begin
@@ -144,7 +144,7 @@ using FillArrays, Zygote # Extensions
         ps = (; weight=x, bias=x, d=(x, x))

         return_val(x) = Val(get_device_type(x))  # If it is a compile time constant then type inference will work
-        @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
+        @constinferred Val{parameterless_type(typeof(device))} return_val(ps)

         return_val2(x) = Val(get_device(x))
         @test_throws ErrorException @inferred(return_val2(ps))
diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl
index 2bc884553..25411ebab 100644
--- a/lib/MLDataDevices/test/metal_tests.jl
+++ b/lib/MLDataDevices/test/metal_tests.jl
@@ -1,4 +1,4 @@
-using MLDataDevices, Random, Test
+using MLDataDevices, Random, Test, TestExtras
 using ArrayInterface: parameterless_type

 @testset "CPU Fallback" begin
@@ -108,10 +108,10 @@ using FillArrays, Zygote # Extensions
         ps = (; weight=x, bias=x, d=(x, x))

         return_val(x) = Val(get_device_type(x))  # If it is a compile time constant then type inference will work
-        @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
+        @constinferred Val{parameterless_type(typeof(device))} return_val(ps)

         return_val2(x) = Val(get_device(x))
-        @test @inferred(return_val2(ps)) isa Val{get_device(x)}
+        @constinferred Val{get_device(x)} return_val2(ps)
     end
 end
diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl
index 65f63c9a9..4573dbe0b 100644
--- a/lib/MLDataDevices/test/misc_tests.jl
+++ b/lib/MLDataDevices/test/misc_tests.jl
@@ -1,4 +1,4 @@
-using Adapt, MLDataDevices, ComponentArrays, Random
+using Adapt, MLDataDevices, ComponentArrays, Random, TestExtras
 using ArrayInterface: parameterless_type
 using ChainRulesTestUtils: test_rrule
 using ReverseDiff, Tracker, ForwardDiff
@@ -148,10 +148,10 @@ end
     ps = (; weight=x, bias=x, d=(x, x))

     return_val(x) = Val(get_device_type(x))  # If it is a compile time constant then type inference will work
-    @test @inferred(return_val(ps)) isa Val{typeof(cpu_device())}
+    @constinferred Val{typeof(cpu_device())} return_val(ps)

     return_val2(x) = Val(get_device(x))
-    @test @inferred(return_val2(ps)) isa Val{cpu_device()}
+    @constinferred Val{cpu_device()} return_val2(ps)
 end

 @testset "undefined references array" begin
diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl
index 2169869d3..355241e54 100644
--- a/lib/MLDataDevices/test/oneapi_tests.jl
+++ b/lib/MLDataDevices/test/oneapi_tests.jl
@@ -1,4 +1,4 @@
-using MLDataDevices, Random, Test
+using MLDataDevices, Random, Test, TestExtras
 using ArrayInterface: parameterless_type

 @testset "CPU Fallback" begin
@@ -108,10 +108,10 @@ using FillArrays, Zygote # Extensions
         ps = (; weight=x, bias=x, d=(x, x))

         return_val(x) = Val(get_device_type(x))  # If it is a compile time constant then type inference will work
-        @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
+        @constinferred Val{parameterless_type(typeof(device))} return_val(ps)

         return_val2(x) = Val(get_device(x))
-        @test @inferred(return_val2(ps)) isa Val{get_device(x)}
+        @constinferred Val{get_device(x)} return_val2(ps)
     end
 end
diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl
index dd59af96e..2853a06cd 100644
--- a/lib/MLDataDevices/test/xla_tests.jl
+++ b/lib/MLDataDevices/test/xla_tests.jl
@@ -1,4 +1,4 @@
-using MLDataDevices, Random, Test
+using MLDataDevices, Random, Test, TestExtras
 using ArrayInterface: parameterless_type

 @testset "CPU Fallback" begin
@@ -108,7 +108,7 @@ using FillArrays, Zygote # Extensions
         ps = (; weight=x, bias=x, d=(x, x))

         return_val(x) = Val(get_device_type(x))  # If it is a compile time constant then type inference will work
-        @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
+        @constinferred Val{parameterless_type(typeof(device))} return_val(ps)

         return_val2(x) = Val(get_device(x))
         @test_throws TypeError @inferred(return_val2(ps))
diff --git a/test/Project.toml b/test/Project.toml
index aca27bdbf..9466d196e 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -37,6 +37,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -77,7 +78,8 @@ SimpleChains = "0.4.7"
 StableRNGs = "1.0.2"
 Static = "1"
 StaticArrays = "1.9"
-Statistics = "1.11.1"
+Statistics = "1.10"
 Test = "1.10"
+TestExtras = "0.3.1"
 Tracker = "0.2.36"
 Zygote = "0.6.70"
diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl
index 31b8fd52b..aa13daa9c 100644
--- a/test/helpers/compact_tests.jl
+++ b/test/helpers/compact_tests.jl
@@ -329,7 +329,7 @@
         @test st_new.incr == 10
         _, st_new = model(x, ps, st_new)
         @test st_new.incr == 100
-        @test @inferred(model(x, ps, st)) isa Any
+        @constinferred model(x, ps, st)

         function ScaledDense2(; d_in=5, d_out=7, act=relu)
             @compact(W=randn(d_out, d_in), b=zeros(d_out), incr=1) do x
@@ -349,10 +349,10 @@
         _, st_new = model(x, ps, st_new)
         @test st_new.incr == 100

-        @test @inferred(model(x, ps, st)) isa Any
+        @constinferred model(x, ps, st)

         __f = (m, x, ps, st) -> sum(abs2, first(m(x, ps, st)))
-        @test @inferred(Zygote.gradient(__f, model, x, ps, st)) isa Any
+        @constinferred Zygote.gradient(__f, model, x, ps, st)
     end

     @testset "Multiple @return" begin
diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl
index ba7934314..7598620fc 100644
--- a/test/helpers/loss_tests.jl
+++ b/test/helpers/loss_tests.jl
@@ -9,8 +9,8 @@
     ∂x2 = Zygote.gradient(LuxOps.xlogx, 2.0)[1]
     @test ∂x1 ≈ ∂x2

-    @test @inferred(LuxOps.xlogx(2)) isa Number
-    @test @inferred(LuxOps.xlogx(0)) isa Number
+    @constinferred LuxOps.xlogx(2)
+    @constinferred LuxOps.xlogx(0)
     @jet LuxOps.xlogx(2)

     @test iszero(LuxOps.xlogy(0, 1))
@@ -33,13 +33,13 @@
         @test_broken false
     end

-    @test @inferred(LuxOps.xlogy(2, 3)) isa Number
-    @test @inferred(LuxOps.xlogy(0, 1)) isa Number
+    @constinferred LuxOps.xlogy(2, 3)
+    @constinferred LuxOps.xlogy(0, 1)
     @jet LuxOps.xlogy(2, 3)

     if LuxTestUtils.ENZYME_TESTING_ENABLED
-        @test @inferred(Enzyme.autodiff(
-            Enzyme.Reverse, LuxOps.xlogy, Active, Active(2.0), Active(3.0))) isa Any
+        @constinferred Enzyme.autodiff(
+            Enzyme.Reverse, LuxOps.xlogy, Active, Active(2.0), Active(3.0))
     else
         @test_broken false
     end
@@ -74,7 +74,7 @@ end
     @test loss_sum(ŷ, y) ≈ loss_res * 4
     @test loss_sum2(ŷ, y) ≈ loss_res * 4

-    @test @inferred(Zygote.gradient(loss_mean, ŷ, y)) isa Any
+    @constinferred Zygote.gradient(loss_mean, ŷ, y)

     @jet loss_mean(ŷ, y)
     @jet loss_sum(ŷ, y)
@@ -91,7 +91,11 @@ end
            @jet MSLELoss()(ŷ, y)

-            @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu
+            if ongpu
+                @constinferred_broken Zygote.gradient(MSLELoss(), ŷ, y)
+            else
+                @constinferred Zygote.gradient(MSLELoss(), ŷ, y)
+            end

            __f = Base.Fix2(MSLELoss(), y)
            @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -150,7 +154,7 @@ end
        @jet celoss(ŷ, y)
        @jet celoss_smooth(ŷ, y)

-        @test @inferred(Zygote.gradient(celoss, ŷ, y)) isa Any
+        @constinferred Zygote.gradient(celoss, ŷ, y)

        @test_gradients(Base.Fix2(celoss, y), ŷ; atol=1.0f-3, rtol=1.0f-3,
            skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
@@ -173,7 +177,7 @@ end
        @jet logitceloss(logŷ, y)
        @jet logitceloss_smooth(logŷ, y)

-        @test @inferred(Zygote.gradient(logitceloss, logŷ, y)) isa Any
+        @constinferred Zygote.gradient(logitceloss, logŷ, y)

        @test_gradients(Base.Fix2(logitceloss, y), logŷ; atol=1.0f-3, rtol=1.0f-3,
            skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
@@ -201,7 +205,7 @@ end
        @jet bceloss(σ.(logŷ), y)
        @jet bceloss_smooth(σ.(logŷ), y)

-        @test @inferred(Zygote.gradient(bceloss, σ.(logŷ), y)) isa Any
+        @constinferred Zygote.gradient(bceloss, σ.(logŷ), y)

        __f = Base.Fix2(bceloss, y)
        σlogŷ = σ.(logŷ)
@@ -223,7 +227,7 @@ end
        @jet logitbceloss(logŷ, y)
        @jet logitbceloss_smooth(logŷ, y)

-        @test @inferred(Zygote.gradient(logitbceloss, logŷ, y)) isa Any
+        @constinferred Zygote.gradient(logitbceloss, logŷ, y)

        __f = Base.Fix2(logitbceloss, y)
        @test_gradients(__f, logŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -246,7 +250,7 @@ end
        @jet BinaryFocalLoss()(ŷ, y)

-        @test @inferred(Zygote.gradient(BinaryFocalLoss(), ŷ, y)) isa Any broken=ongpu
+        if ongpu
+            @constinferred_broken Zygote.gradient(BinaryFocalLoss(), ŷ, y)
+        else
+            @constinferred Zygote.gradient(BinaryFocalLoss(), ŷ, y)
+        end

        __f = Base.Fix2(BinaryFocalLoss(), y)
        @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -270,7 +278,11 @@ end
        @jet FocalLoss()(ŷ, y)

-        @test @inferred(Zygote.gradient(FocalLoss(), ŷ, y)) isa Any broken=ongpu
+        if ongpu
+            @constinferred_broken Zygote.gradient(FocalLoss(), ŷ, y)
+        else
+            @constinferred Zygote.gradient(FocalLoss(), ŷ, y)
+        end

        __f = Base.Fix2(FocalLoss(), y)
        # FD will lead to out of domain errors
@@ -301,7 +313,7 @@ end
        @test KLDivergenceLoss()(y, y) ≈ 0

        @jet KLDivergenceLoss()(ŷ, y)
-        @test @inferred(Zygote.gradient(KLDivergenceLoss(), ŷ, y)) isa Any
+        @constinferred Zygote.gradient(KLDivergenceLoss(), ŷ, y)

        @test_gradients(Base.Fix2(KLDivergenceLoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3,
            skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
@@ -315,7 +327,7 @@ end
        @test Lux.HingeLoss()(y, 0.5 .* y) ≈ 0.125

        @jet Lux.HingeLoss()(ŷ, y)
-        @test @inferred(Zygote.gradient(Lux.HingeLoss(), ŷ, y)) isa Any
+        @constinferred Zygote.gradient(Lux.HingeLoss(), ŷ, y)

        __f = Base.Fix2(Lux.HingeLoss(), y)
        @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -329,7 +341,7 @@ end
        @test SquaredHingeLoss()(y, 0.5 .* y) ≈ 0.0625

        @jet SquaredHingeLoss()(ŷ, y)
-        @inferred Zygote.gradient(SquaredHingeLoss(), ŷ, y)
+        @constinferred Zygote.gradient(SquaredHingeLoss(), ŷ, y)

        __f = Base.Fix2(SquaredHingeLoss(), y)
        @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -343,7 +355,7 @@ end
        @test Lux.PoissonLoss()(y, y) ≈ 0.5044459776946685

        @jet Lux.PoissonLoss()(ŷ, y)
-        @test @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) isa Any
+        @constinferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y)

        __f = Base.Fix2(Lux.PoissonLoss(), y)
        @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -357,7 +369,7 @@ end
        @test DiceCoeffLoss()(y, y) ≈ 0.0

        @jet DiceCoeffLoss()(ŷ, y)
-        @test @inferred(Zygote.gradient(DiceCoeffLoss(), ŷ, y)) isa Any broken=true
+        @constinferred_broken Zygote.gradient(DiceCoeffLoss(), ŷ, y)

        __f = Base.Fix2(DiceCoeffLoss(), y)
        @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3,
diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl
index 222bb7eb2..66716c3b0 100644
--- a/test/helpers/training_tests.jl
+++ b/test/helpers/training_tests.jl
@@ -150,7 +150,7 @@ end
        tstate = Training.TrainState(model, ps, st, opt)

-        _, _, _, tstate_new = @inferred Training.compute_gradients(
+        _, _, _, tstate_new = @constinferred Training.compute_gradients(
            AutoEnzyme(), mse, (x, x), tstate)

        @test tstate_new.states !== tstate.states
@@ -160,13 +160,12 @@ end
        tstate = Training.TrainState(model, ps, st, opt)

-        _, _, _, tstate_new = @inferred Training.compute_gradients(
+        _, _, _, tstate_new = @constinferred Training.compute_gradients(
            AutoEnzyme(), mse, (x, x), tstate)

-        @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa
-            Any
+        @constinferred Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)

-        _, _, _, tstate_new2 = @inferred Training.compute_gradients(
+        _, _, _, tstate_new2 = @constinferred Training.compute_gradients(
            AutoEnzyme(), mse2, (x, x), tstate_new)
        @test hasfield(typeof(tstate_new2.cache.extras), :forward)
        @test hasfield(typeof(tstate_new2.cache.extras), :reverse)
@@ -180,7 +179,7 @@ end
        tstate = Training.TrainState(model, ps, st, opt)

-        _, _, _, tstate_new = @inferred Training.compute_gradients(
+        _, _, _, tstate_new = @constinferred Training.compute_gradients(
            AutoEnzyme(), mse, (x, x), tstate)

        @test tstate_new.states !== tstate.states
@@ -190,13 +189,12 @@ end
        tstate = Training.TrainState(model, ps, st, opt)

-        _, _, _, tstate_new = @inferred Training.compute_gradients(
+        _, _, _, tstate_new = @constinferred Training.compute_gradients(
            AutoEnzyme(), mse, (x, x), tstate)

-        @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa
-            Any
+        @constinferred Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)

-        _, _, _, tstate_new2 = @inferred Training.compute_gradients(
+        _, _, _, tstate_new2 = @constinferred Training.compute_gradients(
            AutoEnzyme(), mse2, (x, x), tstate_new)
        @test hasfield(typeof(tstate_new2.cache.extras), :forward)
        @test hasfield(typeof(tstate_new2.cache.extras), :reverse)
diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl
index cfa7e77d9..887f10c74 100644
--- a/test/layers/containers_tests.jl
+++ b/test/layers/containers_tests.jl
@@ -430,9 +430,13 @@ end
        st = st |> dev
        ps_nt = ps |> dev
-        @test @inferred(froggie(x, ps_nt, st)) isa Any
+        @constinferred froggie(x, ps_nt, st)

        ps_ca = ps |> ComponentArray |> dev
-        @test @inferred(froggie(x, ps_ca, st)) isa Any broken=ongpu
+        if ongpu
+            @constinferred_broken froggie(x, ps_ca, st)
+        else
+            @constinferred froggie(x, ps_ca, st)
+        end
    end
 end
diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl
index e5d853744..db3b0125d 100644
--- a/test/shared_testsetup.jl
+++ b/test/shared_testsetup.jl
@@ -8,7 +8,7 @@ using Lux, Functors
 using Setfield: @set
 using DispatchDoctor: allow_unstable
 @reexport using ComponentArrays, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test,
-                Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff
+                Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff, TestExtras
 using MLDataDevices: default_device_rng, CPUDevice, CUDADevice, AMDGPUDevice
 using LuxTestUtils: check_approx
 using Static: True
diff --git a/test/zygote_type_stability.jl b/test/zygote_type_stability.jl
index 1338ca229..cd7af49d1 100644
--- a/test/zygote_type_stability.jl
+++ b/test/zygote_type_stability.jl
@@ -1,4 +1,4 @@
-using Lux, Random, Zygote, StableRNGs, Test
+using Lux, Random, Zygote, StableRNGs, Test, TestExtras

 include("setup_modes.jl")

@@ -75,13 +75,13 @@ include("setup_modes.jl")
            ps, st = Lux.setup(rng, model) |> dev
            x = input |> dev

-            @test @inferred(model(x, ps, Lux.testmode(st))) isa Any
-            @test @inferred(loss_function(model, x, ps, Lux.testmode(st))) isa Number
+            @constinferred model(x, ps, Lux.testmode(st))
+            @constinferred loss_function(model, x, ps, Lux.testmode(st))
+
            if mode == "amdgpu" && model isa Conv
-                @test_broken @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa
-                    Any
+                @constinferred_broken Zygote.gradient(loss_function, model, x, ps, st)
            else
-                @test @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa Any
+                @constinferred Zygote.gradient(loss_function, model, x, ps, st)
            end
        end
    end