Skip to content

Commit

Permalink
test: try bypassing the world age issues
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 27, 2024
1 parent a5ce201 commit 21dde54
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 28 deletions.
12 changes: 6 additions & 6 deletions lib/LuxLib/test/common_ops/conv_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module ConvSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib

expand(_, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
Expand Down Expand Up @@ -43,19 +43,19 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,

@test eltype(y) == promote_type(Tw, Tx)

@constinferred fused_conv_bias_activation(activation, weight, x, bias, cdims)
@test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

if mode != "amdgpu" && activation !== anonact
@constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)
@test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any
else
try
@constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)
@test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any
catch e
e isa ErrorException || rethrow()
@constinferred_broken Zygote.gradient(
@test_broken @inferred(Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims
)
))
end
end

Expand Down
6 changes: 3 additions & 3 deletions lib/LuxLib/test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module DenseSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs, TestExtras
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs

anonact = x -> x^3

Expand Down Expand Up @@ -27,14 +27,14 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu
@test y y_generic
@test eltype(y) == promote_type(Tw, Tx)

@constinferred fused_dense_bias_activation(activation, w, x, bias)
@test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any
@jet fused_dense_bias_activation(activation, w, x, bias)

atol = 1.0f-3
rtol = 1.0f-3

if activation !== anonact
@constinferred Zygote.gradient(sumabs2dense, activation, w, x, bias)
@test @inferred(Zygote.gradient(sumabs2dense, activation, w, x, bias)) isa Any
end

skip_backends = []
Expand Down
8 changes: 4 additions & 4 deletions lib/LuxLib/test/normalization/batchnorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
end
end

@constinferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)
@test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa
Any
@jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

@test y isa aType{T, length(sz)}
Expand All @@ -88,9 +89,8 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
end

if anonact !== act
lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm(
x, sc, b, rm, rv, tr, act, ϵ)))
@constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, training, act, epsilon)
@test @inferred(Zygote.gradient(
sumabs2first, x, scale, bias, rm, rv, training, act, epsilon)) isa Any
end
end

Expand Down
8 changes: 4 additions & 4 deletions lib/LuxLib/test/normalization/groupnorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module GroupNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs, TestExtras
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs
using LuxTestUtils: check_approx

function setup_groupnorm(rng, aType, T, sz, affine)
Expand Down Expand Up @@ -58,12 +58,12 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
@test ∂bias∂bias_simple atol=atol rtol=rtol
end

@constinferred groupnorm(x, scale, bias, groups, act, epsilon)
@test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
@jet groupnorm(x, scale, bias, groups, act, epsilon)

if anonact !== act
lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ))
@constinferred Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)
@test @inferred(Zygote.gradient(
sumabs2groupnorm, x, scale, bias, groups, act, epsilon)) isa Any
end

@test y isa aType{T, length(sz)}
Expand Down
15 changes: 8 additions & 7 deletions lib/LuxLib/test/normalization/instancenorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module InstanceNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib

is_training(::Val{training}) where {training} = training

Expand All @@ -24,12 +24,12 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
atol = 1.0f-2
rtol = 1.0f-2

@constinferred instancenorm(x, scale, bias, training, act, epsilon)
@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
@jet instancenorm(x, scale, bias, training, act, epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ)))
@constinferred Zygote.gradient(lfn, x, scale, bias, act, epsilon)
@test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any
end

@test y isa aType{T, length(sz)}
Expand All @@ -46,13 +46,14 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)

y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

@constinferred instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)
@test @inferred(instancenorm(
x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any
@jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm(
x, sc, b, rm, rv, Val(true), act, m, ϵ)))
@constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)
@test @inferred(Zygote.gradient(
sumabs2instancenorm, x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa
Any
end

@test y isa aType{T, length(sz)}
Expand Down
6 changes: 3 additions & 3 deletions lib/LuxLib/test/normalization/layernorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function run_layernorm_testing_core(
epsilon = LuxLib.Utils.default_epsilon(T)
_f = (args...) -> layernorm(args..., act, dims, epsilon)

@constinferred layernorm(x, scale, bias, act, dims, epsilon)
@test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any
@jet layernorm(x, scale, bias, act, dims, epsilon)

y = _f(x, scale, bias)
Expand All @@ -60,8 +60,8 @@ function run_layernorm_testing_core(
soft_fail=[AutoFiniteDiff()])

if anonact !== act
lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ))
@constinferred Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)
@test @inferred(Zygote.gradient(
sumabs2layernorm, x, scale, bias, act, dims, epsilon)) isa Any
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ SimpleChains = "0.4.7"
StableRNGs = "1.0.2"
Static = "1"
StaticArrays = "1.9"
Statistics = "1.11.1"
Statistics = "1.10"
Test = "1.10"
TestExtras = "0.3.1"
Tracker = "0.2.36"
Expand Down

0 comments on commit 21dde54

Please sign in to comment.