fix: avoid LV or Octavian with Enzyme
avik-pal committed Nov 18, 2024
1 parent ca6218f commit 40579d1
Showing 7 changed files with 36 additions and 70 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.8"
version = "1.3.9"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
5 changes: 4 additions & 1 deletion lib/LuxLib/src/traits.jl
@@ -77,11 +77,12 @@ end
module System

using ChainRulesCore: ChainRulesCore
+using EnzymeCore: EnzymeCore
using Hwloc: Hwloc
using Static: static, False, True

using ..LuxLib: DISABLE_LOOP_VECTORIZATION
-using ..Utils: is_extension_loaded, safe_minimum
+using ..Utils: is_extension_loaded, safe_minimum, unsafe_known

const CRC = ChainRulesCore

@@ -135,6 +136,8 @@ CRC.@non_differentiable explicit_blas_loaded()
use_octavian() = False()
else
function use_octavian()
+unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() &&
+    return False()
return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
(INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
end
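For context: the two added lines make use_octavian() fall back to the ordinary BLAS path whenever the Enzyme extension is loaded and execution is currently inside an Enzyme autodiff region; EnzymeCore.within_autodiff() is the only real API relied on here. A minimal sketch of the same guard pattern, with a hypothetical enzyme_loaded() standing in for unsafe_known(is_extension_loaded(Val(:Enzyme))):

    using EnzymeCore: EnzymeCore
    using Static: False, True

    # Hypothetical stand-in for unsafe_known(is_extension_loaded(Val(:Enzyme))).
    enzyme_loaded() = true

    function use_octavian_sketch()
        # Bail out of the Octavian fast path while Enzyme is differentiating this code.
        enzyme_loaded() && EnzymeCore.within_autodiff() && return False()
        # The real use_octavian() additionally requires the Octavian extension to be
        # loaded and x86_64 Intel / AMD Ryzen hardware.
        return True()
    end

Outside of an Enzyme.autodiff call, EnzymeCore.within_autodiff() returns false, so ordinary forward evaluation still gets the fast path.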
3 changes: 3 additions & 0 deletions lib/LuxLib/src/utils.jl
@@ -329,6 +329,9 @@ CRC.@non_differentiable static_training_mode_check(::Any...)
@inline can_loopvec_args(args...) = false
else
@inline function can_loopvec_args(args...)
+# Avoid loop vectorization inside Enzyme autodiff calls
+unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() &&
+    return false
return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
end
end
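The same within_autodiff guard now gates the LoopVectorization fast path. As a hedged illustration of the caller-side pattern, the sketch below uses a simplified, hypothetical should_loopvec() in place of LuxLib's can_loopvec_args (which additionally checks that the extension is loaded and that the arguments are LV-compatible arrays):

    using EnzymeCore: EnzymeCore
    using LoopVectorization: @turbo

    # Simplified, hypothetical gate: vectorize only when not inside Enzyme autodiff.
    should_loopvec() = !EnzymeCore.within_autodiff()

    # Squares x into out (assumed to have matching lengths).
    function square!(out::Vector{Float32}, x::Vector{Float32})
        if should_loopvec()
            @turbo for i in eachindex(x)
                out[i] = x[i] * x[i]
            end
        else
            # Plain fallback loop, taken when Enzyme is differentiating the call.
            @inbounds @simd for i in eachindex(x)
                out[i] = x[i] * x[i]
            end
        end
        return out
    end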
2 changes: 1 addition & 1 deletion lib/LuxTestUtils/src/utils.jl
@@ -65,7 +65,7 @@ function aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N}
y = length(x) > 1 ? reduce(vcat, x) : reduce(vcat, [x[1], x[1]])[1:1]
return reshape(y, size(x))
end
-aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x)
+aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal, N}) where {N} = Tracker.collect(x)

function needs_gradient(y)
leaves = Functors.fleaves(y)
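Aside from the whitespace fix above, a note on what aos_to_soa does here: it converts an array of tracked scalars (array-of-structs) into a single tracked array (struct-of-arrays), which for Tracker is exactly Tracker.collect. A small usage sketch, assuming Tracker's standard behavior rather than anything from this repository:

    using Tracker

    x = Tracker.param.([1.0f0, 2.0f0, 3.0f0])  # Vector of TrackedReal{Float32}
    y = Tracker.collect(x)                      # a single TrackedArray
    sum(y)                                      # still participates in Tracker's AD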
40 changes: 0 additions & 40 deletions test/helpers/size_propagator_test.jl

This file was deleted.

7 changes: 4 additions & 3 deletions test/layers/basic_tests.jl
@@ -63,7 +63,8 @@
@jet layer(x, ps, st)

x = randn(rng, 6, 3) |> aType
-@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoReverseDiff()])
end

@testset "SelectDim Layer" begin
@@ -299,8 +300,8 @@ end

@test LuxCore.outputsize(layer, (x, y), rng) == (3,)
@jet layer((x, y), ps, st)
-@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3,
-    skip_backends=[AutoEnzyme()])
+@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3,
+    rtol=1.0f-3, skip_backends=[AutoEnzyme()])
end
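A note on the keyword arguments being shuffled in these tests: as I understand the LuxTestUtils API, skip_backends prevents @test_gradients from evaluating a backend at all, while broken_backends still runs it and records the result as an expected failure, so the suite notices if that backend starts passing. A self-contained sketch; the sumabs2first definition is an assumption mirroring how the test suite appears to use it:

    using Lux, LuxTestUtils, Random
    using ADTypes: AutoFiniteDiff, AutoReverseDiff

    # Assumed helper: sum of squares of the model's first return value (its output).
    sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

    rng = Random.default_rng()
    layer = Dense(3 => 2)
    ps, st = Lux.setup(rng, layer)
    x = randn(rng, Float32, 3, 4)

    # skip_backends: do not evaluate these backends at all.
    @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
        skip_backends=[AutoFiniteDiff()])
    # broken_backends=[AutoReverseDiff()] would instead run ReverseDiff and record the
    # outcome as an expected failure (a backend that actually passes then shows up as
    # an unexpected pass).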

@testset "Inner interactions" begin
47 changes: 23 additions & 24 deletions test/layers/normalize_tests.jl
@@ -43,19 +43,18 @@
@test x_[1] ≈ (1 .- 0.3) / sqrt(1.3) atol=1.0e-5

@jet m(x, ps, st)
-__f = (x, ps) -> sum(first(m(x, ps, st)))
-@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
-    skip_backends=[AutoFiniteDiff()])
+@test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
+    rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])

-for affine in (true, false)
+@testset for affine in (true, false)
m = BatchNorm(2; affine, track_stats=false)
x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] |> aType
display(m)
ps, st = Lux.setup(rng, m) |> dev

@jet m(x, ps, Lux.testmode(st))
-@test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
-    skip_backends=[AutoFiniteDiff()])
+@test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
+    rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])

# with activation function
m = BatchNorm(2, sigmoid; affine)
@@ -68,16 +67,8 @@
@test y ≈
    sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon))
@jet m(x, ps, Lux.testmode(st))

-if affine
-    __f = (x, ps) -> sum(first(m(x, ps, st)))
-    @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
-        skip_backends=[AutoFiniteDiff()])
-else
-    __f = x -> sum(first(m(x, ps, st)))
-    @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-        skip_backends=[AutoFiniteDiff()])
-end
+@test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
+    rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])

m = BatchNorm(32; affine)
x = randn(Float32, 416, 416, 32, 1) |> aType
@@ -170,31 +161,35 @@ end
x = randn(rng, Float32, 3, 3, 3, 1) |> aType

@jet wn(x, ps, st)
-@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoEnzyme()])

wn = WeightNorm(c, (:weight,))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 3, 3, 1) |> aType

@jet wn(x, ps, st)
-@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoEnzyme()])

wn = WeightNorm(c, (:weight, :bias), (2, 2))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 3, 3, 1) |> aType

@jet wn(x, ps, st)
-@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoEnzyme()])

wn = WeightNorm(c, (:weight,), (2,))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 3, 3, 1) |> aType

@jet wn(x, ps, st)
-@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoEnzyme()])
end

@testset "Dense" begin
@@ -206,31 +201,35 @@ end
x = randn(rng, Float32, 3, 1) |> aType

@jet wn(x, ps, st)
-@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoEnzyme()])

wn = WeightNorm(d, (:weight,))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 1) |> aType

@jet wn(x, ps, st)
-@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoEnzyme()])

wn = WeightNorm(d, (:weight, :bias), (2, 2))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 1) |> aType

@jet wn(x, ps, st)
-@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoEnzyme()])

wn = WeightNorm(d, (:weight,), (2,))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 1) |> aType

@jet wn(x, ps, st)
-@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
+@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
+    broken_backends=[AutoEnzyme()])
end

# See https://github.com/LuxDL/Lux.jl/issues/95
