Skip to content

Commit

Permalink
fix: don't force ::Real
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 28, 2024
1 parent cda5278 commit 896675c
Show file tree
Hide file tree
Showing 17 changed files with 63 additions and 63 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.10"
version = "1.3.11"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/ext/LuxLibTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),
Utils.is_tracked(RM, RV, S, B, XT) || continue

@eval Tracker.@grad_from_chainrules LuxLib.Impl.batchnorm_cudnn(
γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool)
γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m, ϵ, training::StaticBool)
end

# Utils extensions
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ include("batchnorm.jl")
function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, 5}},
γ::Optional{<:CuVector{T}}, β::Optional{<:CuVector{T}},
rμ::Optional{<:CuVector{T}}, rσ²::Optional{<:CuVector{T}},
training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F}
training::StaticBool, σ::F, m, ϵ) where {T <: cuDNNFloat, F}
rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training)
y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1]
return Impl.activation!!(σ, y), safe_vec(rμₙ), safe_vec(rσ²ₙ)
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/src/api/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mean and variance.
function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity,
momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N}
momentum=0.1f0, epsilon=default_epsilon(x)) where {F, T, N}
σ = select_fastest_activation(act, x, γ, β, rμ, rσ²)
y, rμ, rσ² = batchnorm_impl(
x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²),
Expand Down
4 changes: 2 additions & 2 deletions lib/LuxLib/src/api/groupnorm.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@doc doc"""
groupnorm(x, scale, bias, groups::Int, σ::F=identity,
epsilon::Real=eps(eltype(x)) ^ (5 // 7))
epsilon=eps(eltype(x)) ^ (5 // 7))
Group Normalization. For details see [1].
Expand Down Expand Up @@ -30,7 +30,7 @@ The normalized array is returned.
"""
function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity,
epsilon::Real=default_epsilon(x)) where {F, N}
epsilon=default_epsilon(x)) where {F, N}
assert_valid_groupnorm_arguments(x, scale, bias, groups)
return groupnorm_impl(
x, scale, bias, groups, select_fastest_activation(σ, x, scale, bias), epsilon)
Expand Down
4 changes: 2 additions & 2 deletions lib/LuxLib/src/api/instancenorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ mean and variance.
"""
function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, training::TrainingType,
σ::F=identity, epsilon::Real=default_epsilon(x)) where {F}
σ::F=identity, epsilon=default_epsilon(x)) where {F}
# This API is kept for legacy purposes when we didn't support passing running stats
return instancenorm(x, γ, β, nothing, nothing, training, σ, nothing, epsilon)
end

function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
rσ²::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity,
momentum::Optional{<:Real}=0.1f0, epsilon::Real=default_epsilon(x)) where {F}
momentum::Optional{<:Real}=0.1f0, epsilon=default_epsilon(x)) where {F}
assert_valid_instancenorm_arguments(x)

y, rμₙ, rσ²ₙ = instancenorm_impl(
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/src/api/layernorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Normalized Array of same size as `x`.
"""
function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray},
bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1),
epsilon::Real=default_epsilon(x)) where {F, xT, N}
epsilon=default_epsilon(x)) where {F, xT, N}
return layernorm_impl(
x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon)
end
26 changes: 13 additions & 13 deletions lib/LuxLib/src/impl/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ CRC.@non_differentiable get_batchnorm_statistics(::Any...)
function batchnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
rσ²::Optional{<:AbstractVector}, training::StaticBool, act::F,
momentum::Real, ϵ::Real) where {F, xT, N}
momentum, ϵ) where {F, xT, N}
(μ, σ²), (rμ, rσ²) = compute_batch_statistics(
x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²),
batchnorm_reduce_dims(x), training, momentum)
Expand All @@ -37,15 +37,15 @@ end
function batchnorm_affine_normalize(
act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N}
β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N}
return batchnorm_affine_normalize(
internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ)
end

function batchnorm_affine_normalize(
::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N}
β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N}
return affine_normalize(
act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ)
end
Expand All @@ -54,7 +54,7 @@ function batchnorm_affine_normalize(
opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N},
μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N},
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector},
ϵ::Real) where {F, xT, μT, σ²T, N}
ϵ) where {F, xT, μT, σ²T, N}
x′ = reshape(x, :, size(x, N - 1), size(x, N))
return reshape(
batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ),
Expand All @@ -64,7 +64,7 @@ end
@stable default_mode="disable" function batchnorm_affine_normalize_internal(
opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 3},
μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT}
β::Optional{<:AbstractVector}, ϵ) where {F, xT}
y = similar(x,
promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
safe_eltype(γ), safe_eltype(β)))
Expand All @@ -75,7 +75,7 @@ end
function batchnorm_affine_normalize_internal!(
y::AbstractArray{yT, 3}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 3},
μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, ϵ::Real,
β::Optional{<:AbstractVector}, ϵ,
γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT}
N = size(y, 2)
γ′ = γ′ === nothing ?
Expand Down Expand Up @@ -225,7 +225,7 @@ end
function batchnorm_affine_normalize_internal!(
y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3},
μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, ϵ::Real,
β::Optional{<:AbstractVector}, ϵ,
γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT}
backend = KA.get_backend(y)
run_ka_kernel(
Expand Down Expand Up @@ -278,7 +278,7 @@ function CRC.rrule(
cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm_affine_normalize_internal),
opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{T, N},
μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N}
β::Optional{<:AbstractVector}, ϵ) where {F, T, N}
y = similar(x,
promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
safe_eltype(γ), safe_eltype(β)))
Expand All @@ -304,7 +304,7 @@ end

function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{∂yT, 3},
x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector,
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real,
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ,
γ′::AbstractVector) where {∂yT, xT}
∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²)
∂γ = γ === nothing ? nothing : similar(γ)
Expand All @@ -322,7 +322,7 @@ function ∇batchnorm_affine_normalize_cpu!(
∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT},
∂σ²::AbstractVector{∂σ²T}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 3},
x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing,
ϵ::Real, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
ϵ, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
half = eltype(∂σ²)(0.5)

fill!(∂μ, 0)
Expand Down Expand Up @@ -361,7 +361,7 @@ function ∇batchnorm_affine_normalize_cpu!(
∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT},
∂σ²::AbstractVector{∂σ²T}, ∂γ::AbstractVector{∂γT},
∂β::AbstractVector{∂βT}, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3},
μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real,
μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ,
γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT}
half = eltype(∂σ²)(0.5)

Expand Down Expand Up @@ -406,7 +406,7 @@ end
function ∇batchnorm_affine_normalize(
opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 3},
x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector,
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real,
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ,
γ′::AbstractVector) where {∂yT, xT}
∂x, ∂σ² = similar(x), similar(σ², size(x))
∂γ = γ === nothing ? nothing : similar(γ, size(x))
Expand All @@ -425,7 +425,7 @@ function ∇batchnorm_affine_normalize!(
∂x::AbstractArray{∂xT, 3}, ∂σ²::AbstractArray{∂σ²T, 3},
∂γ::Optional{<:AbstractArray{<:Any, 3}}, ::GPUBroadcastOp,
∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector,
σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real,
σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ,
γ′::AbstractVector) where {∂xT, ∂σ²T, ∂yT, xT}
backend = KA.get_backend(∂x)
run_ka_kernel(
Expand Down
14 changes: 7 additions & 7 deletions lib/LuxLib/src/impl/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,22 @@ function alpha_dropout(noise::AbstractArray, p, x::AbstractArray, α, A, B)
end

@stable default_mode="disable" function alpha_dropout(
::AbstractInternalArrayOpMode, noise::AbstractArray, p::Real,
x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T}
::AbstractInternalArrayOpMode, noise::AbstractArray, p,
x::AbstractArray{T}, α, A, B) where {T}
A′, B′, α = T(A), T(B), T(α)
return @. muladd(ifelse(noise > p, x, α), A′, B′)
end

@stable default_mode="disable" function alpha_dropout(
opmode::LoopedArrayOp, noise::AbstractArray, p::Real,
x::AbstractArray, α::Real, A::Real, B::Real)
opmode::LoopedArrayOp, noise::AbstractArray, p,
x::AbstractArray, α, A, B)
res = similar(x, promote_type(typeof(p), typeof(α)))
alpha_dropout!(res, opmode, noise, p, x, α, A, B)
return res
end

function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArray,
p::Real, x::AbstractArray, α::Real, A::Real, B::Real)
p, x::AbstractArray, α, A, B)
cond = similar(noise, Bool)
y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x)))
@simd ivdep for I in eachindex(noise, x, y, cond)
Expand All @@ -99,7 +99,7 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra
end

function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode,
noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real)
noise::AbstractArray, p, x::AbstractArray, α, A, B)
cond = noise .> p
y = @. ifelse(cond, x, α) * A + B

Expand All @@ -114,7 +114,7 @@ end

function alpha_dropout!(
res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T},
p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T}
p, x::AbstractArray{T}, α, A, B) where {T}
@simd ivdep for I in eachindex(noise, x, res)
res[I] = ifelse(noise[I] > p, x[I], α) * A + B
end
Expand Down
Loading

0 comments on commit 896675c

Please sign in to comment.