Skip to content

Commit

Permalink
fix: use return_type instead of _return_type (#1148)
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub authored Dec 27, 2024
1 parent bcbcbce commit 90997a0
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray) wher
end
@stable default_mode="disable" function activation(
opmode::LoopedArrayOp, σ::F, x::AbstractArray{T}) where {F, T}
RT = Core.Compiler._return_type(σ, Tuple{T})
RT = Core.Compiler.return_type(σ, Tuple{T})
y = similar(x, ifelse(isconcretetype(RT), RT, T))
activation!(y, opmode, σ, x)
return y
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end

function batched_matmul(opmode::GPUBroadcastOp{<:AbstractGPUDevice},
x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT}
if isconcretetype(Core.Compiler._return_type(
if isconcretetype(Core.Compiler.return_type(
NNlib.batched_mul, Tuple{typeof(x), typeof(y)}))
return NNlib.batched_mul(x, y) # GPU versions are well optimized
end
Expand Down
4 changes: 2 additions & 2 deletions lib/LuxLib/src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ end
activation_intermediate_not_needed(::typeof(identity), ::Type) = True()

function activation_intermediate_not_needed(::F, ::Type{T}) where {F, T}
return static(isconcretetype(Core.Compiler._return_type(
return static(isconcretetype(Core.Compiler.return_type(
only_derivative, Tuple{T, F, NotaNumber})))
end

function activation_has_rrule(::F, ::Type{T}) where {F, T}
return static(isconcretetype(Core.Compiler._return_type(
return static(isconcretetype(Core.Compiler.return_type(
only_derivative, Tuple{T, F, T})))
end

Expand Down
4 changes: 2 additions & 2 deletions lib/LuxLib/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ safe_vec(::Nothing) = nothing

## This part is taken from NNlib.jl
# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
# is independent of `x`, as `return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end

# This just saves typing `only.(only.(` many times:
Expand Down Expand Up @@ -118,7 +118,7 @@ CRC.@non_differentiable default_epsilon(::Any...)
function concrete_bias_act_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx},
b::Optional{<:AbstractVector}) where {F, Tw, Tx}
Ty = promote_type(Tw, Tx, safe_eltype(b))
Tact = Core.Compiler._return_type(act, Tuple{Ty})
Tact = Core.Compiler.return_type(act, Tuple{Ty})
return ifelse(isconcretetype(Tact), Tact, Ty)
end

Expand Down
4 changes: 2 additions & 2 deletions src/helpers/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end
fused_agg(::typeof(sum), op::OP, x::Number, y::Number) where {OP} = op(x, y)
function fused_agg(::typeof(sum), op::OP, x::AbstractArray, y::AbstractArray) where {OP}
if fast_scalar_indexing(x) && fast_scalar_indexing(y)
res = Core.Compiler._return_type(op, Tuple{eltype(x), eltype(y)})(0)
res = Core.Compiler.return_type(op, Tuple{eltype(x), eltype(y)})(0)
@simd ivdep for i in eachindex(x, y)
@inbounds res += op(x[i], y[i])
end
Expand Down Expand Up @@ -73,7 +73,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
Nothing, eltype(x), 1}.(x, (Partials{1, eltype(x)}((one(eltype(x)),)),))
x_partials = similar(x)
T = eltype(x)
res = Core.Compiler._return_type(op, Tuple{T, eltype(y)})(0)
res = Core.Compiler.return_type(op, Tuple{T, eltype(y)})(0)
@inbounds @simd for i in eachindex(x_partials, x, y)
x_dual = Dual{Nothing, T, 1}(x[i], Partials{1, T}((one(T),)))
tmp = op(x_dual, y[i])
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ unbatched_structure(x) = fmapstructure(size_unbatched, x)
can_named_tuple(::NamedTuple) = true
can_named_tuple(::T) where {T} = can_named_tuple(T)
function can_named_tuple(::Type{T}) where {T}
return Core.Compiler._return_type(named_tuple, Tuple{T}) !== Union{}
return Core.Compiler.return_type(named_tuple, Tuple{T}) !== Union{}
end

@non_differentiable can_named_tuple(::Any)

# Convert to a NamedTuple
named_tuple(nt::NamedTuple) = nt
function named_tuple(x::T) where {T}
NT = Core.Compiler._return_type(NamedTuple, Tuple{T})
NT = Core.Compiler.return_type(NamedTuple, Tuple{T})
if NT === Union{} || NT === NamedTuple
error("`NamedTuple` is not defined for type `$(T)`. Please define \
`Lux.Utils.named_tuple(::$(T))` method (or preferably \
Expand Down

0 comments on commit 90997a0

Please sign in to comment.