From 896675c33bcdce35b2227785fec03185e4f45ffa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 28 Dec 2024 13:41:53 -0500 Subject: [PATCH] fix: don't force ::Real --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 2 +- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 2 +- lib/LuxLib/src/api/batchnorm.jl | 2 +- lib/LuxLib/src/api/groupnorm.jl | 4 +- lib/LuxLib/src/api/instancenorm.jl | 4 +- lib/LuxLib/src/api/layernorm.jl | 2 +- lib/LuxLib/src/impl/batchnorm.jl | 26 ++++++------- lib/LuxLib/src/impl/dropout.jl | 14 +++---- lib/LuxLib/src/impl/groupnorm.jl | 38 +++++++++---------- lib/LuxLib/src/impl/layernorm.jl | 2 +- lib/LuxLib/src/impl/normalization.jl | 6 +-- .../test/normalization/batchnorm_tests.jl | 2 +- .../test/normalization/groupnorm_tests.jl | 2 +- src/helpers/losses.jl | 8 ++-- src/helpers/size_propagator.jl | 8 ++-- src/utils.jl | 2 +- 17 files changed, 63 insertions(+), 63 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index aa927150be..9966f52fa5 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.10" +version = "1.3.11" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index d7b0225937..a7234c5eb4 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -97,7 +97,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), Utils.is_tracked(RM, RV, S, B, XT) || continue @eval Tracker.@grad_from_chainrules LuxLib.Impl.batchnorm_cudnn( - γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool) + γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m, ϵ, training::StaticBool) end # Utils extensions diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 77e59d3e4b..b35af417bc 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -21,7 +21,7 @@ include("batchnorm.jl") function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, 5}}, γ::Optional{<:CuVector{T}}, β::Optional{<:CuVector{T}}, rμ::Optional{<:CuVector{T}}, rσ²::Optional{<:CuVector{T}}, - training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F} + training::StaticBool, σ::F, m, ϵ) where {T <: cuDNNFloat, F} rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training) y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1] return Impl.activation!!(σ, y), safe_vec(rμₙ), safe_vec(rσ²ₙ) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 05964f0c6b..bba8a5af27 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -37,7 +37,7 @@ mean and variance. function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity, - momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N} + momentum=0.1f0, epsilon=default_epsilon(x)) where {F, T, N} σ = select_fastest_activation(act, x, γ, β, rμ, rσ²) y, rμ, rσ² = batchnorm_impl( x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²), diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 4e6a7bff86..1053ff9dfe 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -1,6 +1,6 @@ @doc doc""" groupnorm(x, scale, bias, groups::Int, σ::F=identity, - epsilon::Real=eps(eltype(x)) ^ (5 // 7)) + epsilon=eps(eltype(x)) ^ (5 // 7)) Group Normalization. For details see [1]. @@ -30,7 +30,7 @@ The normalized array is returned. """ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, - epsilon::Real=default_epsilon(x)) where {F, N} + epsilon=default_epsilon(x)) where {F, N} assert_valid_groupnorm_arguments(x, scale, bias, groups) return groupnorm_impl( x, scale, bias, groups, select_fastest_activation(σ, x, scale, bias), epsilon) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 1587855242..259e14bb4e 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -36,7 +36,7 @@ mean and variance. """ function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, training::TrainingType, - σ::F=identity, epsilon::Real=default_epsilon(x)) where {F} + σ::F=identity, epsilon=default_epsilon(x)) where {F} # This API is kept for legacy purposes when we didn't support passing running stats return instancenorm(x, γ, β, nothing, nothing, training, σ, nothing, epsilon) end @@ -44,7 +44,7 @@ end function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity, - momentum::Optional{<:Real}=0.1f0, epsilon::Real=default_epsilon(x)) where {F} + momentum::Optional{<:Real}=0.1f0, epsilon=default_epsilon(x)) where {F} assert_valid_instancenorm_arguments(x) y, rμₙ, rσ²ₙ = instancenorm_impl( diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index eb147d30ef..7148fbb0da 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -36,7 +36,7 @@ Normalized Array of same size as `x`. """ function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray}, bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1), - epsilon::Real=default_epsilon(x)) where {F, xT, N} + epsilon=default_epsilon(x)) where {F, xT, N} return layernorm_impl( x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon) end diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 995aacf857..d37bee3464 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -27,7 +27,7 @@ CRC.@non_differentiable get_batchnorm_statistics(::Any...) function batchnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::StaticBool, act::F, - momentum::Real, ϵ::Real) where {F, xT, N} + momentum, ϵ) where {F, xT, N} (μ, σ²), (rμ, rσ²) = compute_batch_statistics( x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), batchnorm_reduce_dims(x), training, momentum) @@ -37,7 +37,7 @@ end function batchnorm_affine_normalize( act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N} + β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N} return batchnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end @@ -45,7 +45,7 @@ end function batchnorm_affine_normalize( ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N} + β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N} return affine_normalize( act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end @@ -54,7 +54,7 @@ function batchnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, - ϵ::Real) where {F, xT, μT, σ²T, N} + ϵ) where {F, xT, μT, σ²T, N} x′ = reshape(x, :, size(x, N - 1), size(x, N)) return reshape( batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ), @@ -64,7 +64,7 @@ end @stable default_mode="disable" function batchnorm_affine_normalize_internal( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT} + β::Optional{<:AbstractVector}, ϵ) where {F, xT} y = similar(x, promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β))) @@ -75,7 +75,7 @@ end function batchnorm_affine_normalize_internal!( y::AbstractArray{yT, 3}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real, + β::Optional{<:AbstractVector}, ϵ, γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} N = size(y, 2) γ′ = γ′ === nothing ? @@ -225,7 +225,7 @@ end function batchnorm_affine_normalize_internal!( y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real, + β::Optional{<:AbstractVector}, ϵ, γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} backend = KA.get_backend(y) run_ka_kernel( @@ -278,7 +278,7 @@ function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm_affine_normalize_internal), opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{T, N}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + β::Optional{<:AbstractVector}, ϵ) where {F, T, N} y = similar(x, promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β))) @@ -304,7 +304,7 @@ end function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ, γ′::AbstractVector) where {∂yT, xT} ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) ∂γ = γ === nothing ? nothing : similar(γ) @@ -322,7 +322,7 @@ function ∇batchnorm_affine_normalize_cpu!( ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT}, ∂σ²::AbstractVector{∂σ²T}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing, - ϵ::Real, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT} + ϵ, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -361,7 +361,7 @@ function ∇batchnorm_affine_normalize_cpu!( ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT}, ∂σ²::AbstractVector{∂σ²T}, ∂γ::AbstractVector{∂γT}, ∂β::AbstractVector{∂βT}, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, - μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real, + μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT} half = eltype(∂σ²)(0.5) @@ -406,7 +406,7 @@ end function ∇batchnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ, γ′::AbstractVector) where {∂yT, xT} ∂x, ∂σ² = similar(x), similar(σ², size(x)) ∂γ = γ === nothing ? nothing : similar(γ, size(x)) @@ -425,7 +425,7 @@ function ∇batchnorm_affine_normalize!( ∂x::AbstractArray{∂xT, 3}, ∂σ²::AbstractArray{∂σ²T, 3}, ∂γ::Optional{<:AbstractArray{<:Any, 3}}, ::GPUBroadcastOp, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, - σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, + σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ, γ′::AbstractVector) where {∂xT, ∂σ²T, ∂yT, xT} backend = KA.get_backend(∂x) run_ka_kernel( diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 5b4248291f..10dda2f69e 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -62,22 +62,22 @@ function alpha_dropout(noise::AbstractArray, p, x::AbstractArray, α, A, B) end @stable default_mode="disable" function alpha_dropout( - ::AbstractInternalArrayOpMode, noise::AbstractArray, p::Real, - x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} + ::AbstractInternalArrayOpMode, noise::AbstractArray, p, + x::AbstractArray{T}, α, A, B) where {T} A′, B′, α = T(A), T(B), T(α) return @. muladd(ifelse(noise > p, x, α), A′, B′) end @stable default_mode="disable" function alpha_dropout( - opmode::LoopedArrayOp, noise::AbstractArray, p::Real, - x::AbstractArray, α::Real, A::Real, B::Real) + opmode::LoopedArrayOp, noise::AbstractArray, p, + x::AbstractArray, α, A, B) res = similar(x, promote_type(typeof(p), typeof(α))) alpha_dropout!(res, opmode, noise, p, x, α, A, B) return res end function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + p, x::AbstractArray, α, A, B) cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) @simd ivdep for I in eachindex(noise, x, y, cond) @@ -99,7 +99,7 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra end function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode, - noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + noise::AbstractArray, p, x::AbstractArray, α, A, B) cond = noise .> p y = @. ifelse(cond, x, α) * A + B @@ -114,7 +114,7 @@ end function alpha_dropout!( res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T}, - p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} + p, x::AbstractArray{T}, α, A, B) where {T} @simd ivdep for I in eachindex(noise, x, res) res[I] = ifelse(noise[I] > p, x[I], α) * A + B end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 9a64fd7350..df52b8508b 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -3,7 +3,7 @@ groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 1 CRC.@non_differentiable groupnorm_reduce_dims(::Any) function groupnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ::Real) where {F, N, xT} + β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ) where {F, N, xT} x′ = reshape(x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N)) (μ, σ²), _ = compute_batch_statistics( x′, nothing, nothing, groupnorm_reduce_dims(x), False(), nothing) @@ -13,7 +13,7 @@ end function groupnorm_affine_normalize( act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T} return groupnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end @@ -21,7 +21,7 @@ end function groupnorm_affine_normalize( ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T} return affine_normalize( act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end @@ -29,7 +29,7 @@ end @generated function groupnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + β::Optional{<:AbstractVector}, ϵ) where {F, N, xT, μT, σ²T} reshape_calls = if γ != Nothing quote γ′ = reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1) @@ -57,7 +57,7 @@ end opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {F, xT, μT, σ²T} + ϵ) where {F, xT, μT, σ²T} y = similar(x, promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β))) @@ -69,7 +69,7 @@ function groupnorm_affine_normalize_internal!( y::AbstractArray{yT, 4}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {F, xT, yT, μT, σ²T} + ϵ) where {F, xT, yT, μT, σ²T} if unsafe_known(fuse_cpu_activation(act)) groupnorm_affine_normalize_act_cpu!(y, x, μ, σ², γ, β, ϵ, act) else @@ -82,7 +82,7 @@ end function groupnorm_affine_normalize_act_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, act::F) where {F, xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, act::F) where {F, xT, yT, μT, σ²T} if size(y, 1) == 1 groupnorm_affine_normalize_act_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ, act) else @@ -93,7 +93,7 @@ end function groupnorm_affine_normalize_act_3d_serial_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -117,7 +117,7 @@ end function groupnorm_affine_normalize_act_4d_serial_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -145,7 +145,7 @@ end function groupnorm_affine_normalize_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T} if size(y, 1) == 1 groupnorm_affine_normalize_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ) else @@ -156,7 +156,7 @@ end @inline function groupnorm_affine_normalize_3d_serial_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -180,7 +180,7 @@ end @inline function groupnorm_affine_normalize_4d_serial_cpu!( y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -209,7 +209,7 @@ function groupnorm_affine_normalize_internal!( y::AbstractArray{yT, 4}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {F, xT, yT, μT, σ²T} + ϵ) where {F, xT, yT, μT, σ²T} backend = KA.get_backend(y) run_ka_kernel( groupnorm_affine_normalize_kernel!, backend, nothing, size(y), @@ -240,7 +240,7 @@ function CRC.rrule( opmode::AbstractInternalArrayOpMode, f::F, x::AbstractArray{T, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {F, T, μT, σ²T} + ϵ) where {F, T, μT, σ²T} y = similar(x, promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β))) @@ -264,7 +264,7 @@ function ∇groupnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {∂yT, xT, μT, σ²T} + ϵ) where {∂yT, xT, μT, σ²T} ∂x, ∂σ² = similar(x), similar(σ², size(x)) ∂γ = γ === nothing ? nothing : similar(γ, size(x)) @@ -281,7 +281,7 @@ end function ∇groupnorm_affine_normalize(::LoopedArrayOp, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {∂yT, xT, μT, σ²T} + ϵ) where {∂yT, xT, μT, σ²T} ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) ∂γ = γ === nothing ? nothing : similar(γ) ∂β = β === nothing ? nothing : similar(β) @@ -298,7 +298,7 @@ function ∇groupnorm_affine_normalize_cpu!( ∂x::AbstractArray{∂xT, 4}, ∂μ::AbstractArray{∂μT, 4}, ∂σ²::AbstractArray{∂σ²T, 4}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, ::Nothing, - ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT, μT, σ²T} + ϵ) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT, μT, σ²T} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -340,7 +340,7 @@ function ∇groupnorm_affine_normalize_cpu!( ∂γ::AbstractArray{∂γT, 4}, ∂β::AbstractArray{∂βT, 4}, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::AbstractArray{γT, 4}, - ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT, μT, σ²T, γT} + ϵ) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT, μT, σ²T, γT} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -391,7 +391,7 @@ function ∇groupnorm_affine_normalize!( ∂γ::Optional{<:AbstractArray{<:Any, 4}}, ::GPUBroadcastOp, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, - ϵ::Real) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T} + ϵ) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T} backend = KA.get_backend(∂x) run_ka_kernel( ∇groupnorm_affine_normalize_kernel!, backend, nothing, size(∂x), diff --git a/lib/LuxLib/src/impl/layernorm.jl b/lib/LuxLib/src/impl/layernorm.jl index 4655972670..6c27d1ac31 100644 --- a/lib/LuxLib/src/impl/layernorm.jl +++ b/lib/LuxLib/src/impl/layernorm.jl @@ -1,7 +1,7 @@ # TODO: For the `dims === nothing` case, we can optimize using a loop vectorization and # kernel abstractions function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray}, - β::Optional{<:AbstractArray}, act::F, dims, epsilon::Real) where {N, F, xT} + β::Optional{<:AbstractArray}, act::F, dims, epsilon) where {N, F, xT} μ, σ² = mean_var(x; dims=compute_layernorm_dims(x, γ, β, dims), corrected=false) γ′, β′ = expand_layernorm_dims(x, γ, β, dims) return affine_normalize(act, x, μ, σ², γ′, β′, epsilon) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index c2c11f12ab..8fa64dda0b 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,14 +1,14 @@ # In most cases this implementation should not be preferred. But this is nice to have # because it works for arbitrary dimensions function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric, - ::Nothing, ::Nothing, ϵ::Real) where {F} + ::Nothing, ::Nothing, ϵ) where {F} γ′ = @. inv(sqrt(σ² + ϵ)) β′ = @. -μ * γ′ return @. act(x * γ′ + β′) end function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric, - γ::AbstractArray, β::AbstractArray, ϵ::Real) where {F} + γ::AbstractArray, β::AbstractArray, ϵ) where {F} γ′ = @. γ / sqrt(σ² + ϵ) β′ = @. β - μ * γ′ return @. act(x * γ′ + β′) @@ -69,7 +69,7 @@ end function update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{rμT, N}, rσ²::AbstractArray{rσ²T, N}, μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, - momentum::Real, reduce_dims) where {T, N, rμT, rσ²T, μT, σ²T} + momentum, reduce_dims) where {T, N, rμT, rσ²T, μT, σ²T} if last(reduce_dims) != N μ = mean(μ; dims=N) σ² = mean(σ²; dims=N) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 58b6196c1a..8d30f4285d 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -21,7 +21,7 @@ function batchnorm_fallback( bias::LuxLib.Optional{<:AbstractVector}, running_mean::LuxLib.Optional{<:AbstractVector}, running_var::LuxLib.Optional{<:AbstractVector}, training::Val, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} + σ::F=identity, momentum=0.1f0, epsilon=1.0f-5) where {F, N} y, xm, xv = LuxLib.Impl.normalization(x, LuxLib.Utils.remove_tracking(running_mean), LuxLib.Utils.remove_tracking(running_var), scale, bias, LuxLib.Impl.batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index c103595f99..f54a0ebf5e 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -16,7 +16,7 @@ end function groupnorm_fallback( x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, bias::LuxLib.Optional{<:AbstractVector}, groups::Int, - σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + σ::F=identity, epsilon=1.0f-5) where {F, N} sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) y, _, _ = LuxLib.Impl.normalization(x_reshaped, nothing, nothing, scale, bias, diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index 9a8f575b6b..00c4cae59d 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -92,7 +92,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, return CRC.rrule_via_ad(cfg, fallback_fused_agg, sum, op, x, y) end -get_ϵ(::Type{T}, ϵ::Real) where {T} = T(ϵ) +get_ϵ(::Type{T}, ϵ) where {T} = T(ϵ) get_ϵ(::Type{T}, ::Nothing) where {T} = eps(float(T)) get_loss_dims(::AbstractVector) = Colon() @@ -160,13 +160,13 @@ function msle_loss(x::T1, y::T2, ϵ) where {T1, T2} end label_smoothing(::Nothing, y, ::Type{T}) where {T} = y -function label_smoothing(label_smoothing::Real, y, ::Type{T}) where {T} +function label_smoothing(label_smoothing, y, ::Type{T}) where {T} label_smoothing = T(label_smoothing) return y .* (1 - label_smoothing) .+ label_smoothing ./ size(y, ndims(y) - 1) end label_smoothing_binary(::Nothing, y, ::Type{T}) where {T} = y -function label_smoothing_binary(label_smoothing::Real, y, ::Type{T}) where {T} +function label_smoothing_binary(label_smoothing, y, ::Type{T}) where {T} label_smoothing = T(label_smoothing) return y .* (1 - label_smoothing) .+ label_smoothing ./ 2 end @@ -725,7 +725,7 @@ true invariant mapping." 2006 IEEE computer society conference on computer vision and pattern recognition (CVPR'06). Vol. 2. IEEE, 2006. """ -function SiameseContrastiveLoss(; margin::Real=true, agg=mean) +function SiameseContrastiveLoss(; margin=true, agg=mean) @argcheck margin ≥ 0 return GenericLossFunction( Utils.Fix3(LossFunctionImpl.siamese_contrastive_loss, margin); agg) diff --git a/src/helpers/size_propagator.jl b/src/helpers/size_propagator.jl index fc0d12b78a..6e67453406 100644 --- a/src/helpers/size_propagator.jl +++ b/src/helpers/size_propagator.jl @@ -152,12 +152,12 @@ end function LuxLib.Impl.batchnorm( x::AnyNilArray{N}, ::Optional{<:AbstractVector}, ::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, - ::StaticBool, act::F, ::Real, ::Real) where {N, F} + ::StaticBool, act::F, ::Number, ::Number) where {N, F} return x, rμ, rσ² end function LuxLib.Impl.groupnorm(x::AnyNilArray{N}, ::Optional{<:AbstractVector}, - ::Optional{<:AbstractVector}, ::Int, act::F, ::Real) where {N, F} + ::Optional{<:AbstractVector}, ::Int, act::F, ::Number) where {N, F} return x end @@ -168,11 +168,11 @@ function LuxLib.Impl.normalization(x::AnyNilArray, rμ::Optional{<:AbstractVecto end function LuxLib.Impl.affine_normalize( - ::F, x::AnyNilArray, ::Numeric, ::Numeric, ::Nothing, ::Nothing, ::Real) where {F} + ::F, x::AnyNilArray, ::Numeric, ::Numeric, ::Nothing, ::Nothing, ::Number) where {F} return x end function LuxLib.Impl.affine_normalize(::F, x::AnyNilArray, ::Numeric, ::Numeric, - ::AbstractArray, ::AbstractArray, ::Real) where {F} + ::AbstractArray, ::AbstractArray, ::Number) where {F} return x end diff --git a/src/utils.jl b/src/utils.jl index eb24d2e25d..99429f5c14 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -225,7 +225,7 @@ calculate_gain(::typeof(NNlib.tanh_fast), _) = 5.0f0 / 3.0f0 function calculate_gain(::typeof(NNlib.leakyrelu), ::Nothing) return calculate_gain(NNlib.leakyrelu, 0.1f0) end -calculate_gain(::typeof(NNlib.leakyrelu), x::Real) = typeof(x)(√(2 / (1 + x^2))) +calculate_gain(::typeof(NNlib.leakyrelu), x) = typeof(x)(√(2 / (1 + x^2))) calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4 end