diff --git a/Project.toml b/Project.toml index 5f86abc9d1..9cdf88807d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.4.2" +version = "1.4.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -83,7 +83,7 @@ Adapt = "4.1" ArgCheck = "2.3" ArrayInterface = "7.17.1" CUDA = "5.3.2" -ChainRulesCore = "1.24" +ChainRulesCore = "1.25" Compat = "4.16" ComponentArrays = "0.15.18" ConcreteStructs = "0.2.3" @@ -106,11 +106,11 @@ MPI = "0.20.19" MacroTools = "0.5.13" Markdown = "1.10" NCCL = "0.1.1" -NNlib = "0.9.24" +NNlib = "0.9.26" Optimisers = "0.4.1" Preferences = "1.4.3" Random = "1.10" -Reactant = "0.2.8" +Reactant = "0.2.12" Reexport = "1.2.2" ReverseDiff = "1.15" SIMDTypes = "0.1" diff --git a/docs/Project.toml b/docs/Project.toml index 3eb44b24ef..0e561d7625 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -33,7 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1.10" Adapt = "4" -ChainRulesCore = "1.24" +ChainRulesCore = "1.25" ComponentArrays = "0.15.18" Documenter = "1.4" DocumenterVitepress = "0.1.3" @@ -51,12 +51,12 @@ LuxCore = "1.2" LuxLib = "1.3.4" LuxTestUtils = "1.5" MLDataDevices = "1.6" -NNlib = "0.9.24" +NNlib = "0.9.26" Optimisers = "0.4.1" Pkg = "1.10" Printf = "1.10" Random = "1.10" -Reactant = "0.2.8" +Reactant = "0.2.12" StableRNGs = "1" StaticArrays = "1" WeightInitializers = "1" diff --git a/docs/make.jl b/docs/make.jl index fac7081d55..c9f2e98a3c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -29,7 +29,7 @@ pages = [ "tutorials/intermediate/1_NeuralODE.md", "tutorials/intermediate/2_BayesianNN.md", "tutorials/intermediate/3_HyperNet.md", - "tutorials/intermediate/4_PINN2DPDE.md" + "tutorials/intermediate/4_PINN2DPDE.md", ], "Advanced" => [ "tutorials/advanced/1_GravitationalWaveForm.md" diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 89ce8831fe..6f49b076d0 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -2,13 +2,22 @@ module LuxReactantExt using Enzyme: Enzyme, Const, Duplicated, Active using Optimisers: Optimisers -using Reactant: Reactant, @compile, TracedRArray, TracedRNumber +using Reactant: Reactant, @compile, AnyTracedRArray, TracedRArray, TracedRNumber using Setfield: @set! using Static: False -using Lux: Lux, LuxOps, Training +using Lux: Lux, LuxOps, Training, Utils using Lux.Training: TrainingBackendCache, ReactantBackend +Lux.is_extension_loaded(::Val{:Reactant}) = true + +Utils.to_rarray(x; kwargs...) = Reactant.to_rarray(x; kwargs...) 
+ +function Utils.promote_to(::Type{T}, x::Number) where {T <: Number} + x isa Reactant.TracedType && return x + return Reactant.ConcreteRNumber{T}(x) +end + include("patches.jl") include("training.jl") diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index 8b13789179..f9f4519e0a 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -1 +1,4 @@ +Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x)) +# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint +Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index 605093e1ea..c35d5cb054 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -1,3 +1,28 @@ +mutable struct StatsAndNewStateWrapper + stats::Any + st::Any +end + +function wrapped_objective_function( + fn::F, model, ps, st, data, cache::StatsAndNewStateWrapper +) where {F} + loss, stₙ, stats = fn(model, ps, st, data) + cache.stats = stats + cache.st = stₙ + return loss +end + +function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F} + stats_wrapper = StatsAndNewStateWrapper(nothing, nothing) + res = Enzyme.gradient( + Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), + Const(wrapped_objective_function), Const(objective_function), + Const(model), ps, Const(st), Const(data), Const(stats_wrapper) + ) + loss, dps = res.val, res.derivs[3] + return dps, loss, stats_wrapper.stats, stats_wrapper.st +end + function Lux.Training.compute_gradients_impl( backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} @@ -22,18 +47,33 @@ function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data, return grads, loss, stats, ts end -function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F} - dps = Enzyme.make_zero(ps) - _, (loss, stₙ, stats) = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model), - Duplicated(ps, dps), Const(st), Const(data)) - return dps, loss, stats, stₙ -end - for inplace in ("!", "") fname = Symbol(:single_train_step_impl, inplace) internal_fn = Symbol(:compute_gradients_internal_and_step, inplace) + apply_gradients_fn = Symbol(:apply_gradients, inplace) + update_fn = Symbol(:update, inplace) + + # Ideally users never hit this dispatch but it is still good to have as a fallback + @eval function Lux.Training.$(apply_gradients_fn)( + ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}}, grads + ) + if hasfield(typeof(ts.cache.extras), :update_function) + update_function = ts.cache.extras.update_function + else + update_function = @compile Optimisers.$(update_fn)( + ts.optimizer_state, ts.parameters, grads) + @set! ts.cache.extras = merge(ts.cache.extras, (; update_function)) + end + opt_state, ps = update_function(ts.optimizer_state, ts.parameters, grads) + @set! ts.parameters = ps + @set! ts.optimizer_state = opt_state + @set! ts.step = ts.step + 1 + return ts + end + + # XXX: Should we add a check to ensure the inputs to this function is same as the one + # used in the compiled function? 
We can re-trigger the compilation with a warning @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} compiled_grad_and_step_function = @compile $(internal_fn)( @@ -68,27 +108,13 @@ for inplace in ("!", "") return grads, loss, stats, ts end -end -function compute_gradients_internal_and_step(objective_function::F, model, data, ps, - st, opt_state) where {F} - dps = Enzyme.make_zero(ps) - _, (loss, stₙ, stats) = Enzyme.autodiff( - Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), - Const(objective_function), Active, Const(model), - Duplicated(ps, dps), Const(st), Const(data)) - opt_state, ps = Optimisers.update(opt_state, ps, dps) - return dps, ps, loss, stats, stₙ, opt_state -end - -function compute_gradients_internal_and_step!(objective_function::F, model, data, ps, - st, opt_state) where {F} - dps = Enzyme.make_zero(ps) - _, (loss, stₙ, stats) = Enzyme.autodiff( - Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), - Const(objective_function), Active, Const(model), - Duplicated(ps, dps), Const(st), Const(data)) - # XXX: Inplace updates not actually inplace - opt_state, ps = Optimisers.update!(opt_state, ps, dps) - return dps, ps, loss, stats, stₙ, opt_state + # XXX: Inplace version not actually inplace + @eval function $(internal_fn)( + objective_function::F, model, data, ps, st, opt_state) where {F} + dps, loss, stats, stₙ = compute_gradients_internal( + objective_function, model, data, ps, st) + opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps) + return dps, ps, loss, stats, stₙ, opt_state + end end diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index aa927150be..5c22c0be3f 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.10" +version = "1.3.11" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -77,7 +77,7 @@ LuxCore = "1.2" MKL = "0.7" MLDataDevices = "1.6" Markdown = "1.10" -NNlib = "0.9.24" +NNlib = "0.9.26" Octavian = "0.3.28" Preferences = "1.4.3" Polyester = "0.7.15" diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index d7b0225937..a7234c5eb4 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -97,7 +97,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), Utils.is_tracked(RM, RV, S, B, XT) || continue @eval Tracker.@grad_from_chainrules LuxLib.Impl.batchnorm_cudnn( - γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool) + γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m, ϵ, training::StaticBool) end # Utils extensions diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 77e59d3e4b..b35af417bc 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -21,7 +21,7 @@ include("batchnorm.jl") function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, 5}}, γ::Optional{<:CuVector{T}}, β::Optional{<:CuVector{T}}, rμ::Optional{<:CuVector{T}}, rσ²::Optional{<:CuVector{T}}, - training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F} + training::StaticBool, σ::F, m, ϵ) where {T <: cuDNNFloat, F} rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training) y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1] return 
Impl.activation!!(σ, y), safe_vec(rμₙ), safe_vec(rσ²ₙ) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 05964f0c6b..bba8a5af27 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -37,7 +37,7 @@ mean and variance. function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity, - momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N} + momentum=0.1f0, epsilon=default_epsilon(x)) where {F, T, N} σ = select_fastest_activation(act, x, γ, β, rμ, rσ²) y, rμ, rσ² = batchnorm_impl( x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²), diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 4e6a7bff86..1053ff9dfe 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -1,6 +1,6 @@ @doc doc""" groupnorm(x, scale, bias, groups::Int, σ::F=identity, - epsilon::Real=eps(eltype(x)) ^ (5 // 7)) + epsilon=eps(eltype(x)) ^ (5 // 7)) Group Normalization. For details see [1]. @@ -30,7 +30,7 @@ The normalized array is returned. """ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, - epsilon::Real=default_epsilon(x)) where {F, N} + epsilon=default_epsilon(x)) where {F, N} assert_valid_groupnorm_arguments(x, scale, bias, groups) return groupnorm_impl( x, scale, bias, groups, select_fastest_activation(σ, x, scale, bias), epsilon) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 1587855242..259e14bb4e 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -36,7 +36,7 @@ mean and variance. """ function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, training::TrainingType, - σ::F=identity, epsilon::Real=default_epsilon(x)) where {F} + σ::F=identity, epsilon=default_epsilon(x)) where {F} # This API is kept for legacy purposes when we didn't support passing running stats return instancenorm(x, γ, β, nothing, nothing, training, σ, nothing, epsilon) end @@ -44,7 +44,7 @@ end function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity, - momentum::Optional{<:Real}=0.1f0, epsilon::Real=default_epsilon(x)) where {F} + momentum::Optional{<:Real}=0.1f0, epsilon=default_epsilon(x)) where {F} assert_valid_instancenorm_arguments(x) y, rμₙ, rσ²ₙ = instancenorm_impl( diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index eb147d30ef..7148fbb0da 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -36,7 +36,7 @@ Normalized Array of same size as `x`. 
""" function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray}, bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1), - epsilon::Real=default_epsilon(x)) where {F, xT, N} + epsilon=default_epsilon(x)) where {F, xT, N} return layernorm_impl( x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon) end diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 995aacf857..d37bee3464 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -27,7 +27,7 @@ CRC.@non_differentiable get_batchnorm_statistics(::Any...) function batchnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::StaticBool, act::F, - momentum::Real, ϵ::Real) where {F, xT, N} + momentum, ϵ) where {F, xT, N} (μ, σ²), (rμ, rσ²) = compute_batch_statistics( x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), batchnorm_reduce_dims(x), training, momentum) @@ -37,7 +37,7 @@ end function batchnorm_affine_normalize( act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N} + β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N} return batchnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end @@ -45,7 +45,7 @@ end function batchnorm_affine_normalize( ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N} + β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N} return affine_normalize( act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end @@ -54,7 +54,7 @@ function batchnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, - ϵ::Real) where {F, xT, μT, σ²T, N} + ϵ) where {F, xT, μT, σ²T, N} x′ = reshape(x, :, size(x, N - 1), size(x, N)) return reshape( batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ), @@ -64,7 +64,7 @@ end @stable default_mode="disable" function batchnorm_affine_normalize_internal( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT} + β::Optional{<:AbstractVector}, ϵ) where {F, xT} y = similar(x, promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β))) @@ -75,7 +75,7 @@ end function batchnorm_affine_normalize_internal!( y::AbstractArray{yT, 3}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real, + β::Optional{<:AbstractVector}, ϵ, γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} N = size(y, 2) γ′ = γ′ === nothing ? 
@@ -225,7 +225,7 @@ end function batchnorm_affine_normalize_internal!( y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real, + β::Optional{<:AbstractVector}, ϵ, γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} backend = KA.get_backend(y) run_ka_kernel( @@ -278,7 +278,7 @@ function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm_affine_normalize_internal), opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{T, N}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + β::Optional{<:AbstractVector}, ϵ) where {F, T, N} y = similar(x, promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β))) @@ -304,7 +304,7 @@ end function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ, γ′::AbstractVector) where {∂yT, xT} ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) ∂γ = γ === nothing ? nothing : similar(γ) @@ -322,7 +322,7 @@ function ∇batchnorm_affine_normalize_cpu!( ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT}, ∂σ²::AbstractVector{∂σ²T}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing, - ϵ::Real, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT} + ϵ, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -361,7 +361,7 @@ function ∇batchnorm_affine_normalize_cpu!( ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT}, ∂σ²::AbstractVector{∂σ²T}, ∂γ::AbstractVector{∂γT}, ∂β::AbstractVector{∂βT}, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, - μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real, + μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT} half = eltype(∂σ²)(0.5) @@ -406,7 +406,7 @@ end function ∇batchnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ, γ′::AbstractVector) where {∂yT, xT} ∂x, ∂σ² = similar(x), similar(σ², size(x)) ∂γ = γ === nothing ? 
nothing : similar(γ, size(x)) @@ -425,7 +425,7 @@ function ∇batchnorm_affine_normalize!( ∂x::AbstractArray{∂xT, 3}, ∂σ²::AbstractArray{∂σ²T, 3}, ∂γ::Optional{<:AbstractArray{<:Any, 3}}, ::GPUBroadcastOp, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, - σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, + σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ, γ′::AbstractVector) where {∂xT, ∂σ²T, ∂yT, xT} backend = KA.get_backend(∂x) run_ka_kernel( diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 5b4248291f..10dda2f69e 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -62,22 +62,22 @@ function alpha_dropout(noise::AbstractArray, p, x::AbstractArray, α, A, B) end @stable default_mode="disable" function alpha_dropout( - ::AbstractInternalArrayOpMode, noise::AbstractArray, p::Real, - x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} + ::AbstractInternalArrayOpMode, noise::AbstractArray, p, + x::AbstractArray{T}, α, A, B) where {T} A′, B′, α = T(A), T(B), T(α) return @. muladd(ifelse(noise > p, x, α), A′, B′) end @stable default_mode="disable" function alpha_dropout( - opmode::LoopedArrayOp, noise::AbstractArray, p::Real, - x::AbstractArray, α::Real, A::Real, B::Real) + opmode::LoopedArrayOp, noise::AbstractArray, p, + x::AbstractArray, α, A, B) res = similar(x, promote_type(typeof(p), typeof(α))) alpha_dropout!(res, opmode, noise, p, x, α, A, B) return res end function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + p, x::AbstractArray, α, A, B) cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) @simd ivdep for I in eachindex(noise, x, y, cond) @@ -99,7 +99,7 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra end function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode, - noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + noise::AbstractArray, p, x::AbstractArray, α, A, B) cond = noise .> p y = @. 
ifelse(cond, x, α) * A + B @@ -114,7 +114,7 @@ end function alpha_dropout!( res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T}, - p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} + p, x::AbstractArray{T}, α, A, B) where {T} @simd ivdep for I in eachindex(noise, x, res) res[I] = ifelse(noise[I] > p, x[I], α) * A + B end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 9a64fd7350..df52b8508b 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -3,7 +3,7 @@ groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 1 CRC.@non_differentiable groupnorm_reduce_dims(::Any) function groupnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ::Real) where {F, N, xT} + β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ) where {F, N, xT} x′ = reshape(x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N)) (μ, σ²), _ = compute_batch_statistics( x′, nothing, nothing, groupnorm_reduce_dims(x), False(), nothing) @@ -13,7 +13,7 @@ end function groupnorm_affine_normalize( act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T} return groupnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end @@ -21,7 +21,7 @@ end function groupnorm_affine_normalize( ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T} return affine_normalize( act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end @@ -29,7 +29,7 @@ end @generated function groupnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T} reshape_calls = if γ != Nothing quote γ′ = reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1) @@ -57,7 +57,7 @@ end opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {F, xT, μT, σ²T} + ϵ) where {F, xT, μT, σ²T} y = similar(x, promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β))) @@ -69,7 +69,7 @@ function groupnorm_affine_normalize_internal!( y::AbstractArray{yT, 4}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {F, xT, yT, μT, σ²T} + ϵ) where {F, xT, yT, μT, σ²T} if unsafe_known(fuse_cpu_activation(act)) groupnorm_affine_normalize_act_cpu!(y, x, μ, σ², γ, β, ϵ, act) else @@ -82,7 +82,7 @@ end function groupnorm_affine_normalize_act_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, act::F) where {F, xT, yT, μT, σ²T} + 
β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, act::F) where {F, xT, yT, μT, σ²T} if size(y, 1) == 1 groupnorm_affine_normalize_act_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ, act) else @@ -93,7 +93,7 @@ end function groupnorm_affine_normalize_act_3d_serial_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -117,7 +117,7 @@ end function groupnorm_affine_normalize_act_4d_serial_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -145,7 +145,7 @@ end function groupnorm_affine_normalize_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T} if size(y, 1) == 1 groupnorm_affine_normalize_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ) else @@ -156,7 +156,7 @@ end @inline function groupnorm_affine_normalize_3d_serial_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -180,7 +180,7 @@ end @inline function groupnorm_affine_normalize_4d_serial_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -209,7 +209,7 @@ function groupnorm_affine_normalize_internal!( y::AbstractArray{yT, 4}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {F, xT, yT, μT, σ²T} + ϵ) where {F, xT, yT, μT, σ²T} backend = KA.get_backend(y) run_ka_kernel( groupnorm_affine_normalize_kernel!, backend, nothing, size(y), @@ -240,7 +240,7 @@ function CRC.rrule( opmode::AbstractInternalArrayOpMode, f::F, x::AbstractArray{T, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {F, T, μT, σ²T} + ϵ) where {F, T, μT, σ²T} y = similar(x, promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β))) @@ -264,7 +264,7 @@ function ∇groupnorm_affine_normalize( 
opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {∂yT, xT, μT, σ²T} + ϵ) where {∂yT, xT, μT, σ²T} ∂x, ∂σ² = similar(x), similar(σ², size(x)) ∂γ = γ === nothing ? nothing : similar(γ, size(x)) @@ -281,7 +281,7 @@ end function ∇groupnorm_affine_normalize(::LoopedArrayOp, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {∂yT, xT, μT, σ²T} + ϵ) where {∂yT, xT, μT, σ²T} ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) ∂γ = γ === nothing ? nothing : similar(γ) ∂β = β === nothing ? nothing : similar(β) @@ -298,7 +298,7 @@ function ∇groupnorm_affine_normalize_cpu!( ∂x::AbstractArray{∂xT, 4}, ∂μ::AbstractArray{∂μT, 4}, ∂σ²::AbstractArray{∂σ²T, 4}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, ::Nothing, - ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT, μT, σ²T} + ϵ) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT, μT, σ²T} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -340,7 +340,7 @@ function ∇groupnorm_affine_normalize_cpu!( ∂γ::AbstractArray{∂γT, 4}, ∂β::AbstractArray{∂βT, 4}, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::AbstractArray{γT, 4}, - ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT, μT, σ²T, γT} + ϵ) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT, μT, σ²T, γT} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -391,7 +391,7 @@ function ∇groupnorm_affine_normalize!( ∂γ::Optional{<:AbstractArray{<:Any, 4}}, ::GPUBroadcastOp, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T} + ϵ) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T} backend = KA.get_backend(∂x) run_ka_kernel( ∇groupnorm_affine_normalize_kernel!, backend, nothing, size(∂x), diff --git a/lib/LuxLib/src/impl/layernorm.jl b/lib/LuxLib/src/impl/layernorm.jl index 4655972670..6c27d1ac31 100644 --- a/lib/LuxLib/src/impl/layernorm.jl +++ b/lib/LuxLib/src/impl/layernorm.jl @@ -1,7 +1,7 @@ # TODO: For the `dims === nothing` case, we can optimize using a loop vectorization and # kernel abstractions function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray}, - β::Optional{<:AbstractArray}, act::F, dims, epsilon::Real) where {N, F, xT} + β::Optional{<:AbstractArray}, act::F, dims, epsilon) where {N, F, xT} μ, σ² = mean_var(x; dims=compute_layernorm_dims(x, γ, β, dims), corrected=false) γ′, β′ = expand_layernorm_dims(x, γ, β, dims) return affine_normalize(act, x, μ, σ², γ′, β′, epsilon) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index c2c11f12ab..8fa64dda0b 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,14 +1,14 @@ # In most cases this implementation should not be preferred. But this is nice to have # because it works for arbitrary dimensions function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric, - ::Nothing, ::Nothing, ϵ::Real) where {F} + ::Nothing, ::Nothing, ϵ) where {F} γ′ = @. inv(sqrt(σ² + ϵ)) β′ = @. -μ * γ′ return @. 
act(x * γ′ + β′) end function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric, - γ::AbstractArray, β::AbstractArray, ϵ::Real) where {F} + γ::AbstractArray, β::AbstractArray, ϵ) where {F} γ′ = @. γ / sqrt(σ² + ϵ) β′ = @. β - μ * γ′ return @. act(x * γ′ + β′) @@ -69,7 +69,7 @@ end function update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{rμT, N}, rσ²::AbstractArray{rσ²T, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, - momentum::Real, reduce_dims) where {T, N, rμT, rσ²T, μT, σ²T} + momentum, reduce_dims) where {T, N, rμT, rσ²T, μT, σ²T} if last(reduce_dims) != N μ = mean(μ; dims=N) σ² = mean(σ²; dims=N) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 403bc57fb5..df34c29520 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -49,7 +49,7 @@ LoopVectorization = "0.12.171" LuxTestUtils = "1.5" MKL = "0.7" MLDataDevices = "1.6" -NNlib = "0.9.21" +NNlib = "0.9.26" Octavian = "0.3.28" Pkg = "1.10" Random = "1.10" diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 58b6196c1a..8d30f4285d 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -21,7 +21,7 @@ function batchnorm_fallback( bias::LuxLib.Optional{<:AbstractVector}, running_mean::LuxLib.Optional{<:AbstractVector}, running_var::LuxLib.Optional{<:AbstractVector}, training::Val, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} + σ::F=identity, momentum=0.1f0, epsilon=1.0f-5) where {F, N} y, xm, xv = LuxLib.Impl.normalization(x, LuxLib.Utils.remove_tracking(running_mean), LuxLib.Utils.remove_tracking(running_var), scale, bias, LuxLib.Impl.batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index c103595f99..f54a0ebf5e 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -16,7 +16,7 @@ end function groupnorm_fallback( x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, bias::LuxLib.Optional{<:AbstractVector}, groups::Int, - σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + σ::F=identity, epsilon=1.0f-5) where {F, N} sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) y, _, _ = LuxLib.Impl.normalization(x_reshaped, nothing, nothing, scale, bias, diff --git a/src/Lux.jl b/src/Lux.jl index 29014cfa6f..64f0af07f1 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -44,6 +44,7 @@ include("utils.jl") include("extended_ops.jl") # Training Helpers +include("helpers/optimizers.jl") include("helpers/training.jl") # Experimental diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index 9a8f575b6b..00c4cae59d 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -92,7 +92,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, return CRC.rrule_via_ad(cfg, fallback_fused_agg, sum, op, x, y) end -get_ϵ(::Type{T}, ϵ::Real) where {T} = T(ϵ) +get_ϵ(::Type{T}, ϵ) where {T} = T(ϵ) get_ϵ(::Type{T}, ::Nothing) where {T} = eps(float(T)) get_loss_dims(::AbstractVector) = Colon() @@ -160,13 +160,13 @@ function msle_loss(x::T1, y::T2, ϵ) where {T1, T2} end label_smoothing(::Nothing, y, ::Type{T}) where {T} = y -function label_smoothing(label_smoothing::Real, y, ::Type{T}) where {T} +function label_smoothing(label_smoothing, y, 
::Type{T}) where {T} label_smoothing = T(label_smoothing) return y .* (1 - label_smoothing) .+ label_smoothing ./ size(y, ndims(y) - 1) end label_smoothing_binary(::Nothing, y, ::Type{T}) where {T} = y -function label_smoothing_binary(label_smoothing::Real, y, ::Type{T}) where {T} +function label_smoothing_binary(label_smoothing, y, ::Type{T}) where {T} label_smoothing = T(label_smoothing) return y .* (1 - label_smoothing) .+ label_smoothing ./ 2 end @@ -725,7 +725,7 @@ true invariant mapping." 2006 IEEE computer society conference on computer vision and pattern recognition (CVPR'06). Vol. 2. IEEE, 2006. """ -function SiameseContrastiveLoss(; margin::Real=true, agg=mean) +function SiameseContrastiveLoss(; margin=true, agg=mean) @argcheck margin ≥ 0 return GenericLossFunction( Utils.Fix3(LossFunctionImpl.siamese_contrastive_loss, margin); agg) diff --git a/src/helpers/optimizers.jl b/src/helpers/optimizers.jl new file mode 100644 index 0000000000..fe0116bb4e --- /dev/null +++ b/src/helpers/optimizers.jl @@ -0,0 +1,184 @@ +# This is mostly an internal implementation detail that users shouldn't need to worry about. +# We can remove this once https://github.com/FluxML/Optimisers.jl/issues/205 is resolved. +module ReactantCompatibleOptimisers + +using ConcreteStructs: @concrete +using Optimisers: Optimisers, AbstractRule +using Setfield: Setfield, @set! + +using ..Lux: Lux, Utils + +abstract type ReactantCompatibleOptimisersRule <: AbstractRule end + +function make_reactant_compatible(opt::AbstractRule) + @warn "`make_reactant_compatible` is not defined for $(opt). Returning the original \ + optimizer. This means adjusting learning rate and other parameters won't \ + reflect in the generated MLIR." maxlog=1 + return opt +end +make_reactant_compatible(opt::ReactantCompatibleOptimisersRule) = opt + +function setfield_if_present(opt, field::Symbol, nt::NamedTuple) + if hasfield(typeof(nt), field) + opt = Setfield.set( + opt, Setfield.PropertyLens{field}(), + convert( + typeof(getproperty(opt, field)), + Utils.to_rarray(getproperty(nt, field); track_numbers=true) + ) + ) + end + return opt +end + +# OptimiserChain +function make_reactant_compatible(opt::Optimisers.OptimiserChain) + return Optimisers.OptimiserChain(make_reactant_compatible.(opt.opts)) +end + +# Descent +@concrete struct ReactantDescent <: ReactantCompatibleOptimisersRule + eta +end + +function make_reactant_compatible(opt::Optimisers.Descent) + return ReactantDescent(Utils.to_rarray(opt.eta; track_numbers=true)) +end + +Optimisers.init(::ReactantDescent, ::AbstractArray) = nothing + +function Optimisers.apply!(opt::ReactantDescent, state, x::AbstractArray{T}, dx) where {T} + η = T(opt.eta) + return state, @. dx * η +end + +function Optimisers._adjust(opt::ReactantDescent, nt::NamedTuple) + return setfield_if_present(opt, :eta, nt) +end + +# Momentum +@concrete struct ReactantMomentum <: ReactantCompatibleOptimisersRule + eta + rho +end + +function make_reactant_compatible(opt::Optimisers.Momentum) + return ReactantMomentum( + Utils.to_rarray(opt.eta; track_numbers=true), + Utils.to_rarray(opt.rho; track_numbers=true) + ) +end + +function Optimisers.init(::ReactantMomentum, x::AbstractArray) + return Optimisers.init(Optimisers.Momentum(0.0, 0.0), x) +end + +function Optimisers.apply!(opt::ReactantMomentum, mvel, ::AbstractArray{T}, dx) where {T} + η, ρ = T(opt.eta), T(opt.rho) + @. 
mvel = ρ * mvel + η * dx + return mvel, mvel +end + +function Optimisers._adjust(opt::ReactantMomentum, nt::NamedTuple) + opt = setfield_if_present(opt, :eta, nt) + opt = setfield_if_present(opt, :rho, nt) + return opt +end + +# Adam +@concrete struct ReactantAdam <: ReactantCompatibleOptimisersRule + eta + beta + epsilon +end + +function make_reactant_compatible(opt::Optimisers.Adam) + return ReactantAdam( + Utils.to_rarray(opt.eta; track_numbers=true), + Utils.to_rarray(opt.beta; track_numbers=true), + Utils.to_rarray(opt.epsilon; track_numbers=true) + ) +end + +function Optimisers.init(opt::ReactantAdam, x::AbstractArray{T}) where {T} + return ( + zero(x), + zero(x), + (Utils.promote_to(T, opt.beta[1]), Utils.promote_to(T, opt.beta[2])) + ) +end + +function Optimisers.apply!(o::ReactantAdam, state, ::AbstractArray{T}, dx) where {T} + η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon) # XXX: See Optimisers._eps + mt, vt, βt = state + + @. mt = β[1] * mt + (1 - β[1]) * dx + @. vt = β[2] * vt + (1 - β[2]) * abs2(dx) + dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η + + return (mt, vt, βt .* β), dx′ +end + +function Optimisers._adjust(opt::ReactantAdam, nt::NamedTuple) + opt = setfield_if_present(opt, :eta, nt) + opt = setfield_if_present(opt, :beta, nt) + opt = setfield_if_present(opt, :epsilon, nt) + return opt +end + +# AdamW +@concrete struct ReactantAdamW <: ReactantCompatibleOptimisersRule + eta + beta + lambda + epsilon + couple::Bool +end + +function make_reactant_compatible(opt::Optimisers.AdamW) + return ReactantAdamW( + Utils.to_rarray(opt.eta; track_numbers=true), + Utils.to_rarray(opt.beta; track_numbers=true), + Utils.to_rarray(opt.lambda; track_numbers=true), + Utils.to_rarray(opt.epsilon; track_numbers=true), + opt.couple + ) +end + +function Optimisers.init(opt::ReactantAdamW, x::AbstractArray{T}) where {T} + return ( + zero(x), + zero(x), + (Utils.promote_to(T, opt.beta[1]), Utils.promote_to(T, opt.beta[2])) + ) +end + +function Optimisers.apply!(o::ReactantAdamW, state, x::AbstractArray{T}, dx) where {T} + η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda) # XXX: See Optimisers._eps + mt, vt, βt = state + + # standard Adam update with learning rate eta=1 + @. mt = β[1] * mt + (1 - β[1]) * dx + @. vt = β[2] * vt + (1 - β[2]) * abs2(dx) + dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η + + # apply learning rate and weight decay + if o.couple + dx′′ = @. η * (dx′ + λ * x) + else + dx′′ = @. 
η * dx′ + λ * x + end + + return (mt, vt, βt .* β), dx′′ +end + +function Optimisers._adjust(opt::ReactantAdamW, nt::NamedTuple) + opt = setfield_if_present(opt, :eta, nt) + opt = setfield_if_present(opt, :beta, nt) + opt = setfield_if_present(opt, :lambda, nt) + opt = setfield_if_present(opt, :epsilon, nt) + opt = setfield_if_present(opt, :couple, nt) + return opt +end + +end diff --git a/src/helpers/size_propagator.jl b/src/helpers/size_propagator.jl index fc0d12b78a..6e67453406 100644 --- a/src/helpers/size_propagator.jl +++ b/src/helpers/size_propagator.jl @@ -152,12 +152,12 @@ end function LuxLib.Impl.batchnorm( x::AnyNilArray{N}, ::Optional{<:AbstractVector}, ::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, - ::StaticBool, act::F, ::Real, ::Real) where {N, F} + ::StaticBool, act::F, ::Number, ::Number) where {N, F} return x, rμ, rσ² end function LuxLib.Impl.groupnorm(x::AnyNilArray{N}, ::Optional{<:AbstractVector}, - ::Optional{<:AbstractVector}, ::Int, act::F, ::Real) where {N, F} + ::Optional{<:AbstractVector}, ::Int, act::F, ::Number) where {N, F} return x end @@ -168,11 +168,11 @@ function LuxLib.Impl.normalization(x::AnyNilArray, rμ::Optional{<:AbstractVecto end function LuxLib.Impl.affine_normalize( - ::F, x::AnyNilArray, ::Numeric, ::Numeric, ::Nothing, ::Nothing, ::Real) where {F} + ::F, x::AnyNilArray, ::Numeric, ::Numeric, ::Nothing, ::Nothing, ::Number) where {F} return x end function LuxLib.Impl.affine_normalize(::F, x::AnyNilArray, ::Numeric, ::Numeric, - ::AbstractArray, ::AbstractArray, ::Real) where {F} + ::AbstractArray, ::AbstractArray, ::Number) where {F} return x end diff --git a/src/helpers/training.jl b/src/helpers/training.jl index da2b597a94..c11f74b93f 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -4,14 +4,14 @@ using ADTypes: AbstractADType, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZyg using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure -using Functors: fmap +using Functors: Functors, fmap using Optimisers: Optimisers using Setfield: @set! using Static: StaticBool, Static, False, True -using ..Lux: Lux, Utils +using ..Lux: Lux, Utils, ReactantCompatibleOptimisers using LuxCore: LuxCore, AbstractLuxLayer -using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type, get_device, cpu_device +using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type """ TrainState @@ -63,10 +63,10 @@ Constructor for [`TrainState`](@ref). [`TrainState`](@ref) object. 
""" function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule) - dev = get_device(ps) - st_opt = if dev isa ReactantDevice - ps_cpu = ps |> cpu_device() - Optimisers.setup(optimizer, ps_cpu) |> dev + st_opt = if get_device_type(ps) <: ReactantDevice + Optimisers.setup( + ReactantCompatibleOptimisers.make_reactant_compatible(optimizer), ps + ) else Optimisers.setup(optimizer, ps) end diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index ac1e040082..91964b0dea 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -134,7 +134,9 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple) end function update_batchnorm_state(BN::BatchNorm, st::NamedTuple, stats) - has_track_stats(BN) && return merge(st, (; stats.running_mean, stats.running_var)) + has_track_stats(BN) && return merge(st, + (; running_mean=Utils.vec(stats.running_mean), + running_var=Utils.vec(stats.running_var))) return st end @@ -378,14 +380,23 @@ statelength(l::InstanceNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1 function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(IN, ps, st, x) σ = NNlib.fast_act(IN.activation, x′) - y, _ = instancenorm( + y, stats = instancenorm( x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), safe_getproperty(st, Val(:running_mean)), safe_getproperty(st, Val(:running_var)), st.training, σ, convert(unwrapped_eltype(x′), IN.momentum), convert(unwrapped_eltype(x′), IN.epsilon)) - return y, st + return y, update_instancenorm_state(IN, st, stats) end +function update_instancenorm_state(IN::InstanceNorm, st::NamedTuple, stats) + has_track_stats(IN) && return merge(st, + (; running_mean=Utils.vec(stats.running_mean), + running_var=Utils.vec(stats.running_var))) + return st +end + +CRC.@non_differentiable update_instancenorm_state(::Any...) + function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(l.chs)") (l.activation == identity) || print(io, ", $(l.activation)") diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index 943eb947c1..3be3a5e24e 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -40,15 +40,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode abstract type AbstractPoolOp end struct MaxPoolOp <: AbstractPoolOp end + (m::MaxPoolOp)(x, pdims) = maxpool(x, pdims) +function (m::MaxPoolOp)(x, ::GlobalPoolMode) + return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf)) +end struct MeanPoolOp <: AbstractPoolOp end + (m::MeanPoolOp)(x, pdims) = meanpool(x, pdims) +(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2)) @concrete struct LpPoolOp <: AbstractPoolOp p end + (m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p) +(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p) symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp() symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp() diff --git a/src/utils.jl b/src/utils.jl index 2a6930a2a8..99429f5c14 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -211,6 +211,9 @@ matrix_to_array(x::SMatrix{L, 1, T}, ::AbstractVector) where {L, T} = SVector{L, matrix_to_array(x::AbstractMatrix, ::AbstractMatrix) = x matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:end]...) 
+function to_rarray end +function promote_to end + # This should probably be in WeightInitializers.jl calculate_gain(_, __) = 1.0f0 calculate_gain(::typeof(identity), _) = 1.0f0 @@ -222,7 +225,7 @@ calculate_gain(::typeof(NNlib.tanh_fast), _) = 5.0f0 / 3.0f0 function calculate_gain(::typeof(NNlib.leakyrelu), ::Nothing) return calculate_gain(NNlib.leakyrelu, 0.1f0) end -calculate_gain(::typeof(NNlib.leakyrelu), x::Real) = typeof(x)(√(2 / (1 + x^2))) +calculate_gain(::typeof(NNlib.leakyrelu), x) = typeof(x)(√(2 / (1 + x^2))) calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4 end diff --git a/test/Project.toml b/test/Project.toml index 58dd94c2ee..7f9cb93e5c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -62,7 +62,7 @@ LuxLib = "1.3.4" LuxTestUtils = "1.5" MLDataDevices = "1.6" MLUtils = "0.4.3" -NNlib = "0.9.24" +NNlib = "0.9.26" Octavian = "0.3.28" OneHotArrays = "0.2.5" Optimisers = "0.4.1" diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl index ba79343140..3c0113b5c0 100644 --- a/test/helpers/loss_tests.jl +++ b/test/helpers/loss_tests.jl @@ -46,12 +46,12 @@ @testset "$mode" for (mode, aType, dev, ongpu) in MODES x = rand(10) |> aType - __f = sum ∘ Broadcast.BroadcastFunction(LuxOps.xlogx) - @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()]) + @test_gradients(sum∘Broadcast.BroadcastFunction(LuxOps.xlogx), + x; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()]) y = rand(10) |> aType - __f = sum ∘ Broadcast.BroadcastFunction(LuxOps.xlogy) - @test_gradients(__f, x, y; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()]) + @test_gradients(sum∘Broadcast.BroadcastFunction(LuxOps.xlogy), + x, y; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()]) end end @@ -79,8 +79,7 @@ end @jet loss_mean(ŷ, y) @jet loss_sum(ŷ, y) - __f = Base.Fix2(loss_mean, y) - @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(loss_mean, y), ŷ; atol=1.0f-3, rtol=1.0f-3) end @testset "MSLE" begin @@ -93,8 +92,7 @@ end @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu - __f = Base.Fix2(MSLELoss(), y) - @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(MSLELoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3) end end end @@ -203,9 +201,8 @@ end @test @inferred(Zygote.gradient(bceloss, σ.(logŷ), y)) isa Any - __f = Base.Fix2(bceloss, y) - σlogŷ = σ.(logŷ) - @test_gradients(__f, σlogŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(bceloss, y), σ.(logŷ); atol=1.0f-3, rtol=1.0f-3, + enzyme_set_runtime_activity=true) end @testset "Logit BinaryCrossEntropyLoss" begin @@ -225,8 +222,8 @@ end @test @inferred(Zygote.gradient(logitbceloss, logŷ, y)) isa Any - __f = Base.Fix2(logitbceloss, y) - @test_gradients(__f, logŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(logitbceloss, y), logŷ; atol=1.0f-3, rtol=1.0f-3, + enzyme_set_runtime_activity=true) end @testset "BinaryFocalLoss" begin @@ -248,8 +245,7 @@ end @test @inferred(Zygote.gradient(BinaryFocalLoss(), ŷ, y)) isa Any broken=ongpu - __f = Base.Fix2(BinaryFocalLoss(), y) - @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(BinaryFocalLoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3) end @testset "FocalLoss" begin diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index 6b545e15b6..ec3704d90e 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -56,7 +56,7 @@ @jet m(x, ps, Lux.testmode(st)) @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3, - 
rtol=1.0f-3, skip_backends=[AutoFiniteDiff()], broken_backends) + rtol=1.0f-3, skip_backends=[AutoFiniteDiff()]) # with activation function m = BatchNorm(2, sigmoid; affine) diff --git a/test/reactant/training_tests.jl b/test/reactant/training_tests.jl index 06c550a3ba..6384cd49a6 100644 --- a/test/reactant/training_tests.jl +++ b/test/reactant/training_tests.jl @@ -18,7 +18,9 @@ @testset "MLP Training: $(version)" for version in (:iip, :oop) model = Chain( Dense(2 => 32, gelu), + BatchNorm(32), Dense(32 => 32, gelu), + BatchNorm(32), Dense(32 => 2) ) ps, st = Lux.setup(StableRNG(1234), model) |> xdev @@ -43,27 +45,31 @@ inference_loss_fn_compiled(xᵢ, yᵢ, model, ps, st) end - train_state = Training.TrainState(model, ps, st, Adam(0.01f0)) + @testset for opt in ( + Descent(0.01f0), Momentum(0.01f0), Adam(0.01f0), AdamW(0.01f0) + ) + train_state = Training.TrainState(model, ps, st, opt) - for epoch in 1:100, (xᵢ, yᵢ) in dataloader - grads, loss, stats, train_state = if version === :iip - Training.single_train_step!( - AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state) - elseif version === :oop - Training.single_train_step( - AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state) - else - error("Invalid version: $(version)") + for epoch in 1:100, (xᵢ, yᵢ) in dataloader + grads, loss, stats, train_state = if version === :iip + Training.single_train_step!( + AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state) + elseif version === :oop + Training.single_train_step( + AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state) + else + error("Invalid version: $(version)") + end end - end - total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ) - inference_loss_fn_compiled( - xᵢ, yᵢ, model, train_state.parameters, train_state.states - ) - end + total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ) + inference_loss_fn_compiled( + xᵢ, yᵢ, model, train_state.parameters, train_state.states + ) + end - @test total_final_loss < 100 * total_initial_loss + @test total_final_loss < 100 * total_initial_loss + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index 0f96e8b49f..91db71bcb2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -127,9 +127,7 @@ const RETESTITEMS_NWORKER_THREADS = parse( string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) @testset "Lux.jl Tests" begin - for (i, tag) in enumerate(LUX_TEST_GROUP) - @info "Running tests for group: [$(i)/$(length(LUX_TEST_GROUP))] $tag" - + @testset "[$(tag)] [$(i)/$(length(LUX_TEST_GROUP))]" for (i, tag) in enumerate(LUX_TEST_GROUP) nworkers = (tag == "reactant") || (BACKEND_GROUP == "amdgpu") ? 0 : RETESTITEMS_NWORKERS diff --git a/test/setup_modes.jl b/test/setup_modes.jl index 1617179a5b..b7c581ccca 100644 --- a/test/setup_modes.jl +++ b/test/setup_modes.jl @@ -1,4 +1,4 @@ -using Lux, MLDataDevices +using Lux, MLDataDevices, Pkg if !@isdefined(BACKEND_GROUP) const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))
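
For reviewers: a minimal end-to-end sketch of the new Reactant training path introduced above, closely following the updated `test/reactant/training_tests.jl`. When the parameters live on a `ReactantDevice`, `TrainState` now wraps the optimizer via `ReactantCompatibleOptimisers.make_reactant_compatible`, so hyperparameters such as the learning rate become traced/concrete numbers instead of being baked into the compiled MLIR. The `reactant_device()` and `StableRNGs` setup here is assumed boilerplate, not part of the diff.

```julia
using Lux, Optimisers, Reactant, StableRNGs
using Lux: Training
using ADTypes: AutoEnzyme
using MLDataDevices: reactant_device

dev = reactant_device()

model = Chain(Dense(2 => 32, gelu), BatchNorm(32), Dense(32 => 2))
ps, st = Lux.setup(StableRNG(1234), model) |> dev

x = randn(StableRNG(0), Float32, 2, 32) |> dev
y = randn(StableRNG(1), Float32, 2, 32) |> dev

function train(model, ps, st, x, y; iters=100)
    # Adam(0.01f0) is rewritten internally to a `ReactantAdam` whose eta/beta/epsilon
    # come from `Utils.to_rarray(...; track_numbers=true)` (see `make_reactant_compatible`
    # above), so they enter the traced program as values rather than constants.
    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    for _ in 1:iters
        # The first call compiles and caches the gradient-and-update step;
        # subsequent calls reuse the cached compiled function.
        _, loss, _, train_state = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), train_state)
    end
    return train_state
end

train_state = train(model, ps, st, x, y)
```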
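
Similarly, the `GlobalPoolMode` dispatches added in `src/layers/pooling.jl` and `ext/LuxReactantExt/patches.jl` let global max/mean pooling lower to plain reductions over the spatial dimensions when the input is a traced Reactant array, instead of going through `PoolDims`, working around the missing `stablehlo.reduce_window` adjoint noted in the XXX comment. A hedged sketch of the workflow this is meant to support — the layers are existing Lux layers, while the `reactant_device()`/`@compile` boilerplate is assumed from the usual Reactant workflow:

```julia
using Lux, Random, Reactant
using MLDataDevices: reactant_device

dev = reactant_device()

model = Chain(Conv((3, 3), 3 => 8, relu), GlobalMeanPool(), FlattenLayer(), Dense(8 => 2))
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 16, 16, 3, 4) |> dev

# With the new dispatch, the global pool on a traced array is expressed as a
# mean over dims 1:2 rather than a windowed pooling op.
st_test = Lux.testmode(st)
forward = @compile model(x, ps, st_test)
y, _ = forward(x, ps, st_test)
```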