Skip to content

Commit

Permalink
fix: more test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 21, 2024
1 parent a1ea977 commit c680055
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 32 deletions.
17 changes: 15 additions & 2 deletions lib/LuxLib/ext/LuxLibTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module LuxLibTrackerExt

using FastClosures: @closure
using LuxLib: LuxLib, Utils, Traits
using LuxLib: LuxLib, Utils, Impl, Traits, GenericBroadcastOp
using NNlib: NNlib
using Static: True, StaticBool
using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector
using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector, TrackedMatrix

tracker_data(x) = Tracker.data(x)
tracker_data(x::NNlib.BatchedAdjoint) = NNlib.batched_adjoint(tracker_data(parent(x)))
Expand Down Expand Up @@ -52,6 +52,19 @@ for op in (:batched_mul, :batched_matmul)
end
end

# Overload muladd for Traced Arrays
for AType in (:TrackedMatrix, :AbstractMatrix),
xType in (:TrackedMatrix, :AbstractMatrix),
bType in (:TrackedVector, :AbstractVector)

Utils.is_tracked(AType, xType, bType) || continue

@eval function Impl.matmuladd(
::GenericBroadcastOp, A::$(AType), x::$(xType), b::$(bType))
return A * x .+ b
end
end

# NNlib: gather
Tracker.@grad_from_chainrules NNlib.gather!(
dst::AbstractArray, src::TrackedArray, idx::AbstractArray)
Expand Down
11 changes: 4 additions & 7 deletions lib/LuxLib/src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
batched_matmul_loopvec_impl!(z, x, y)
return
end
# Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983
fallback_batched_matmul!(z, LoopedArrayOp(), x, y)
# NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed
NNlib.batched_mul!(z, x, y)
return
end

Expand All @@ -80,10 +78,9 @@ end
function fallback_batched_matmul!(
z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
# XXX: bring back once the enzyme segfault is fixed
# @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
# $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
# slow." maxlog=1
@warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
$(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
slow." maxlog=1

if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
(size(x, 2) != size(y, 1))
Expand Down
12 changes: 6 additions & 6 deletions lib/LuxLib/test/normalization/batchnorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
end

if is_training(training)
# XXX: Fails due to runtime activity but setting it doesn't help
@test_gradients(sumabs2first, batchnorm, x, scale, bias, Constant(rm),
Constant(rv), training, act, T(0.9), epsilon; atol, rtol,
enzyme_set_runtime_activity=true)
skip_backends=[AutoEnzyme()], enzyme_set_runtime_activity=true)
end

if anonact !== act
Expand Down Expand Up @@ -130,8 +131,6 @@ end

@testitem "Batch Norm: Mixed Precision" tags=[:normalization] setup=[SharedTestSetup] begin
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
!fp64 && aType == Float64 && continue

x = rand(Float64, 4, 4, 6, 2) |> aType
scale = rand(Float32, 6) |> aType
bias = rand(Float32, 6) |> aType
Expand All @@ -144,8 +143,9 @@ end
@test nt.running_mean isa aType && length(nt.running_mean) == 6
@test nt.running_var isa aType && length(nt.running_var) == 6

__f = (args...) -> sum(first(batchnorm(
args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5)))
@test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(
sumabs2first, batchnorm, x, scale, bias, Constant(running_mean),
Constant(running_var), training, act, T(0.9), T(1e-5); atol=1.0f-3, rtol=1.0f-3
)
end
end
10 changes: 6 additions & 4 deletions lib/LuxLib/test/normalization/instancenorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
# First test without running stats
y, nt = instancenorm(x, scale, bias, training, act, epsilon)

atol = 1.0f-3
rtol = 1.0f-3
atol = 1.0f-2
rtol = 1.0f-2

@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
@jet instancenorm(x, scale, bias, training, act, epsilon)
Expand All @@ -37,7 +37,7 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)

if is_training(training)
@test_gradients(sumabs2instancenorm, x, scale, bias, training, act, epsilon;
atol, rtol, soft_fail=[AutoFiniteDiff()])
atol, rtol, soft_fail=[AutoFiniteDiff()], enzyme_set_runtime_activity=true)
end

# Now test with running stats
Expand All @@ -62,7 +62,9 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)

if is_training(training)
@test_gradients(sumabs2instancenorm, x, scale, bias, Constant(rm), Constant(rv),
training, act, T(0.1), epsilon; atol, rtol, soft_fail=[AutoFiniteDiff()])
training, act, T(0.1), epsilon; atol, rtol,
soft_fail=[AutoFiniteDiff()],
enzyme_set_runtime_activity=true)
end
end

Expand Down
3 changes: 2 additions & 1 deletion lib/LuxTestUtils/src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[],
end
catch err
err isa InterruptException && rethrow()
Error(:test_error, local_test_expr, err, Base.current_exceptions(), source)
Error(:test_error, local_test_expr, err,
Base.current_exceptions(), source)
end
end
Test.record(get_testset(), result)
Expand Down
15 changes: 6 additions & 9 deletions test/helpers/loss_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,8 @@ end

@test @inferred(Zygote.gradient(celoss, ŷ, y)) isa Any

# XXX: Failure only on CI
@test_gradients(Base.Fix2(celoss, y), ŷ; atol=1.0f-3, rtol=1.0f-3)
# rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
@test_gradients(Base.Fix2(celoss, y), ŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION v"1.11-" ? [AutoEnzyme()] : [])
end

@testset "Logit CrossEntropyLoss" begin
Expand All @@ -176,9 +175,8 @@ end

@test @inferred(Zygote.gradient(logitceloss, logŷ, y)) isa Any

# XXX: Failure only on CI
@test_gradients(Base.Fix2(logitceloss, y), logŷ; atol=1.0f-3, rtol=1.0f-3)
# rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
@test_gradients(Base.Fix2(logitceloss, y), logŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION v"1.11-" ? [AutoEnzyme()] : [])
end

logŷ, y = randn(3) |> aType, rand(3) |> aType
Expand Down Expand Up @@ -305,9 +303,8 @@ end
@jet KLDivergenceLoss()(ŷ, y)
@test @inferred(Zygote.gradient(KLDivergenceLoss(), ŷ, y)) isa Any

# XXX: Failure only on CI
@test_gradients(Base.Fix2(KLDivergenceLoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3)
# rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
@test_gradients(Base.Fix2(KLDivergenceLoss(), y), ŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION v"1.11-" ? [AutoEnzyme()] : [])
end

@testset "HingeLoss" begin
Expand Down
3 changes: 0 additions & 3 deletions test/layers/recurrent_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ end
@test !hasproperty(ps, :hidden_state)
end

# XXX: Failure only on CI
# skip_backends = VERSION ≥ v"1.11-" && act === identity ? [AutoEnzyme()] : []
@test_gradients(loss_loop, rnncell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
# skip_backends)
end
end

Expand Down

0 comments on commit c680055

Please sign in to comment.