test: try re-enabling enzyme testing on 0.13.16 #1042

Merged: 21 commits (Nov 21, 2024)
8 changes: 4 additions & 4 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.3"
version = "1.3.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -69,15 +69,15 @@ LuxZygoteExt = "Zygote"
ADTypes = "1.10"
Adapt = "4.1"
ArgCheck = "2.3"
ArrayInterface = "7.10"
ArrayInterface = "7.17.1"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.16"
ComponentArrays = "0.15.18"
ConcreteStructs = "0.2.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.13"
EnzymeCore = "0.8.5"
Enzyme = "0.13.16"
EnzymeCore = "0.8.6"
FastClosures = "0.3.2"
Flux = "0.14.25"
ForwardDiff = "0.10.36"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -37,7 +37,7 @@ ChainRulesCore = "1.24"
ComponentArrays = "0.15.18"
Documenter = "1.4"
DocumenterVitepress = "0.1.3"
Enzyme = "0.13.13"
Enzyme = "0.13.16"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
Functors = "0.5"
2 changes: 1 addition & 1 deletion ext/LuxEnzymeExt/training.jl
@@ -1,6 +1,6 @@
function Lux.Training.compute_gradients_impl(
ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Lux.Training.dparameters(ts.cache)
dps = fmap(Utils.zero, ts.parameters; exclude=isleaf)

obj_fn_wrap, st_wrap, stats_wrap = Lux.Training.wrap_objective_function(
obj_fn, ts.model, ts.parameters, ts.states, data, True())
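Note on the hunk above: the Enzyme training path now builds the gradient shadow by zeroing a fresh copy of the parameter tree instead of reusing a cached dparameters buffer. A minimal sketch of that zeroing pattern using Functors.jl directly, assuming Utils.zero behaves like Base.zero on leaf arrays and isleaf is Functors.isleaf:

using Functors: fmap, isleaf

ps = (; weight = rand(Float32, 3, 3), bias = rand(Float32, 3))

# Zero "shadow" with the same nested structure as the parameters; only leaf
# arrays are touched, mirroring fmap(Utils.zero, ts.parameters; exclude=isleaf).
dps = fmap(x -> zero.(x), ps; exclude = isleaf)

@assert all(iszero, dps.weight) && all(iszero, dps.bias)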
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
@@ -34,7 +34,7 @@ ArrayInterface = "7.9"
ChainRulesCore = "1.24"
Compat = "4.16"
DispatchDoctor = "0.4.10"
EnzymeCore = "0.8.5"
EnzymeCore = "0.8.6"
Functors = "0.5"
MLDataDevices = "1.6"
Random = "1.10"
2 changes: 1 addition & 1 deletion lib/LuxCore/test/Project.toml
@@ -11,7 +11,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Aqua = "0.8.7"
EnzymeCore = "0.8.5"
EnzymeCore = "0.8.6"
ExplicitImports = "1.9.0"
Functors = "0.5"
MLDataDevices = "1.6"
6 changes: 3 additions & 3 deletions lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.8"
version = "1.3.9"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -65,8 +65,8 @@ ChainRulesCore = "1.24"
Compat = "4.16"
CpuId = "0.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.13"
EnzymeCore = "0.8.5"
Enzyme = "0.13.16"
EnzymeCore = "0.8.6"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
17 changes: 15 additions & 2 deletions lib/LuxLib/ext/LuxLibTrackerExt.jl
@@ -1,10 +1,10 @@
module LuxLibTrackerExt

using FastClosures: @closure
using LuxLib: LuxLib, Utils, Traits
using LuxLib: LuxLib, Utils, Impl, Traits, GenericBroadcastOp
using NNlib: NNlib
using Static: True, StaticBool
using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector
using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector, TrackedMatrix

tracker_data(x) = Tracker.data(x)
tracker_data(x::NNlib.BatchedAdjoint) = NNlib.batched_adjoint(tracker_data(parent(x)))
@@ -52,6 +52,19 @@ for op in (:batched_mul, :batched_matmul)
end
end

# Overload muladd for Traced Arrays
for AType in (:TrackedMatrix, :AbstractMatrix),
xType in (:TrackedMatrix, :AbstractMatrix),
bType in (:TrackedVector, :AbstractVector)

Utils.is_tracked(AType, xType, bType) || continue

@eval function Impl.matmuladd(
::GenericBroadcastOp, A::$(AType), x::$(xType), b::$(bType))
return A * x .+ b
end
end

# NNlib: gather
Tracker.@grad_from_chainrules NNlib.gather!(
dst::AbstractArray, src::TrackedArray, idx::AbstractArray)
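The loop above @eval-generates Impl.matmuladd methods for every (A, x, b) combination that involves at least one Tracker type (that is what Utils.is_tracked is assumed to select), so Tracker inputs fall back to a plain matmul plus broadcasted add under GenericBroadcastOp. One of the generated methods, written out by hand for illustration:

function Impl.matmuladd(::GenericBroadcastOp, A::TrackedMatrix,
        x::AbstractMatrix, b::AbstractVector)
    return A * x .+ b
end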
16 changes: 9 additions & 7 deletions lib/LuxLib/src/impl/batched_mul.jl
@@ -61,9 +61,12 @@ function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
batched_matmul_loopvec_impl!(z, x, y)
return
end
# Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983
fallback_batched_matmul!(z, LoopedArrayOp(), x, y)
# NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed
if Utils.within_enzyme_autodiff()
# XXX: https://github.com/LuxDL/Lux.jl/issues/1024
fallback_batched_matmul!(z, LoopedArrayOp(), x, y)
else
NNlib.batched_mul!(z, x, y)
end
return
end

@@ -80,10 +83,9 @@ end
function fallback_batched_matmul!(
z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
# XXX: bring back once the enzyme segfault is fixed
# @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
# $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
# slow." maxlog=1
@warn "Using fallback Batched Matrix Multiply routine for $(opmode) with A: size = \
$(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
slow." maxlog=1

if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
(size(x, 2) != size(y, 1))
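With the branch above, batched_matmul_cpu! only takes the looped fallback while inside Enzyme autodiff and otherwise calls NNlib.batched_mul! again. As a rough sketch of what such a looped fallback computes (naive_batched_matmul! is a hypothetical stand-in, not the LuxLib implementation), each batch slice is multiplied separately, broadcasting a singleton batch dimension when one input has size 1 along dim 3:

using LinearAlgebra: mul!

function naive_batched_matmul!(z::AbstractArray{<:Any, 3}, x::AbstractArray{<:Any, 3},
        y::AbstractArray{<:Any, 3})
    for L in axes(z, 3)
        xl = view(x, :, :, size(x, 3) == 1 ? 1 : L)
        yl = view(y, :, :, size(y, 3) == 1 ? 1 : L)
        mul!(view(z, :, :, L), xl, yl)  # z[:, :, L] = xl * yl
    end
    return z
end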
40 changes: 31 additions & 9 deletions lib/LuxLib/src/impl/batchnorm.jl
@@ -96,15 +96,29 @@ function batchnorm_affine_normalize_internal!(
end

function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ)
if γ === nothing && β === nothing
@simd ivdep for J in eachindex(γ′, β′, μ, σ²)
@fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ))
@fastmath @inbounds β′[J] = -μ[J] * γ′[J]
if Utils.within_enzyme_autodiff()
if γ === nothing && β === nothing
for J in eachindex(γ′, β′, μ, σ²)
@inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ))
@inbounds β′[J] = -μ[J] * γ′[J]
end
else
for J in eachindex(γ′, β′, γ, β, μ, σ²)
@inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ)
@inbounds β′[J] = β[J] - μ[J] * γ′[J]
end
end
else
@simd ivdep for J in eachindex(γ′, β′, γ, β, μ, σ²)
@fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ)
@fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J]
if γ === nothing && β === nothing
@simd ivdep for J in eachindex(γ′, β′, μ, σ²)
@fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ))
@fastmath @inbounds β′[J] = -μ[J] * γ′[J]
end
else
@simd ivdep for J in eachindex(γ′, β′, γ, β, μ, σ²)
@fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ)
@fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J]
end
end
end
end
@@ -115,7 +129,11 @@ function apply_batchnorm_scale_bias_act_cpu!(
if size(y, 1) == 1
apply_batchnorm_scale_bias_act_2d_serial_cpu!(y, γ′, β′, x, σ)
else
apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
if Utils.within_enzyme_autodiff()
apply_batchnorm_scale_bias_act_3d_serial_cpu!(y, γ′, β′, x, σ)
else
apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
end
end
end

@@ -160,7 +178,11 @@ function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{yT, 3}, γ′::Abstrac
if size(y, 1) == 1
apply_batchnorm_scale_bias_2d_serial_cpu!(y, γ′, β′, x)
else
apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
if Utils.within_enzyme_autodiff()
apply_batchnorm_scale_bias_3d_serial_cpu!(y, γ′, β′, x)
else
apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
end
end
end

1 change: 1 addition & 0 deletions lib/LuxLib/src/impl/normalization.jl
@@ -129,6 +129,7 @@ reshape_norm_dims(y, x) = reshape(x, get_norm_reshape_dims(size(y), length(x)))
end

CRC.@non_differentiable get_norm_reshape_dims(::Any...)
EnzymeRules.inactive(::typeof(get_norm_reshape_dims), ::Any...) = true

# Entry Points
## InstanceNorm
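The added line marks get_norm_reshape_dims as inactive for Enzyme, mirroring the existing CRC.@non_differentiable rule so that neither AD system tries to differentiate a pure shape computation. A standalone sketch of the same pattern on a hypothetical helper, following the return-true convention used above:

using ChainRulesCore: ChainRulesCore
using EnzymeCore: EnzymeRules

const CRC = ChainRulesCore

# Hypothetical shape helper: it only describes a reshape, so no derivative
# information should flow through it.
compute_reshape_dims(y, x) = (length(x), 1)

CRC.@non_differentiable compute_reshape_dims(::Any...)
EnzymeRules.inactive(::typeof(compute_reshape_dims), ::Any...) = true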
4 changes: 3 additions & 1 deletion lib/LuxLib/src/traits.jl
@@ -77,11 +77,12 @@ end
module System

using ChainRulesCore: ChainRulesCore
using EnzymeCore: EnzymeCore
using Hwloc: Hwloc
using Static: static, False, True

using ..LuxLib: DISABLE_LOOP_VECTORIZATION
using ..Utils: is_extension_loaded, safe_minimum
using ..Utils: is_extension_loaded, safe_minimum, within_enzyme_autodiff

const CRC = ChainRulesCore

@@ -135,6 +136,7 @@ CRC.@non_differentiable explicit_blas_loaded()
use_octavian() = False()
else
function use_octavian()
within_enzyme_autodiff() && return False()
return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
(INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
end
7 changes: 7 additions & 0 deletions lib/LuxLib/src/utils.jl
@@ -284,6 +284,11 @@ within_autodiff(::AbstractArray{<:ForwardDiff.Dual}) = True()

CRC.rrule(::typeof(within_autodiff), x) = True(), _ -> (∂∅, ∂∅)

function within_enzyme_autodiff()
unsafe_known(is_extension_loaded(Val(:Enzyme))) && return EnzymeCore.within_autodiff()
return false
end

static_training_mode(::Nothing, args...) = within_autodiff_vararg(args...)

function static_training_mode(
@@ -329,6 +334,8 @@ CRC.@non_differentiable static_training_mode_check(::Any...)
@inline can_loopvec_args(args...) = false
else
@inline function can_loopvec_args(args...)
# Avoid loop vectorization inside Enzyme autodiff calls
within_enzyme_autodiff() && return false
return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
end
end
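within_enzyme_autodiff above only queries EnzymeCore.within_autodiff() once the Enzyme extension is actually loaded, and the rest of this PR uses it to pick Enzyme-safe paths (no @simd ivdep/@fastmath, no LoopVectorization, no Octavian, no threaded batchnorm kernels). A minimal sketch of the underlying query, assuming Enzyme 0.13 with EnzymeCore 0.8.6 and reaching EnzymeCore through Enzyme's own dependency:

using Enzyme

# within_autodiff() is false in an ordinary call and true when the call is
# evaluated under Enzyme.autodiff.
f(x) = Enzyme.EnzymeCore.within_autodiff() ? x * x : x

f(2.0)                                            # 2.0, plain call
Enzyme.autodiff(Reverse, f, Active, Active(2.0))  # differentiates the x * x branch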
4 changes: 2 additions & 2 deletions lib/LuxLib/test/Project.toml
@@ -38,8 +38,8 @@ BLISBLAS = "0.1"
BenchmarkTools = "1.5"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.18"
Enzyme = "0.13.13"
EnzymeCore = "0.8.5"
Enzyme = "0.13.16"
EnzymeCore = "0.8.6"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
10 changes: 6 additions & 4 deletions lib/LuxLib/test/common_ops/activation_tests.jl
@@ -1,4 +1,6 @@
@testitem "Activation Functions" tags=[:misc] setup=[SharedTestSetup] begin
using Enzyme

rng = StableRNG(1234)

apply_act(f::F, x) where {F} = sum(abs2, f.(x))
@@ -8,7 +10,7 @@
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
T in [Float16, Float32, Float64]
T in [Float32, Float64]

!fp64 && T == Float64 && continue

@@ -41,9 +43,9 @@
end
@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any

@test_gradients(Base.Fix1(apply_act, f), x; atol, rtol)
@test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol)
@test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol)
@test_gradients(apply_act, f, x; atol, rtol)
@test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()])
@test_gradients(apply_act_fast2, f, x; atol, rtol)

∂x1 = Zygote.gradient(apply_act, f, x)[2]
∂x2 = Zygote.gradient(apply_act_fast, f, x)[2]
2 changes: 1 addition & 1 deletion lib/LuxLib/test/common_ops/bias_act_tests.jl
@@ -15,7 +15,7 @@
@testset "$act, $T, $sz" for act in [
identity, relu, sigmoid, sigmoid_fast, softplus,
logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
T in [Float16, Float32, Float64],
T in [Float32, Float64],
sz in [(2, 2, 3, 4), (4, 5)]

!fp64 && T == Float64 && continue
30 changes: 14 additions & 16 deletions lib/LuxLib/test/common_ops/conv_tests.jl
@@ -14,6 +14,8 @@ end

calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = expand(Val(2 * N), pad)

sumabs2conv(args...) = sum(abs2, fused_conv_bias_activation(args...))

function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
hasbias, groups, Tw, Tx, aType, mode, ongpu)
weight = convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType
@@ -28,9 +30,8 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,

generic_testing = !(mode == "amdgpu" && (Tx == Float64 || Tw == Float64))

fp16 = Tx == Float16 || Tw == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

if generic_testing
y_generic = LuxLib.Impl.conv(x, weight, cdims)
@@ -45,36 +46,33 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
@test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

__f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims))

if mode != "amdgpu" && activation !== anonact && !fp16
@test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any
if mode != "amdgpu" && activation !== anonact
@test @inferred(Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims
)) isa Any
else
try
@inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims))
@inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims))
@test true
catch e
e isa ErrorException || rethrow()
@test_broken false
end
end

__f_grad = let activation = activation, cdims = cdims
(w, x, b) -> __f(activation, w, x, b, cdims)
end

skip_backends = Any[AutoEnzyme()]
skip_backends = []
mp = Tx != Tw
mp && push!(skip_backends, AutoReverseDiff())
((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) &&
push!(skip_backends, AutoTracker())
@test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, soft_fail=fp16)

@test_gradients(sumabs2conv, activation, weight, x, bias, cdims; atol, rtol,
skip_backends)
end

anonact = x -> gelu(x)

const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32),
(Float32, Float64), (Float64, Float64)]
const ELTYPES = [(Float32, Float32), (Float32, Float64), (Float64, Float64)]
const ACTIVATIONS = [
identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact]
