diff --git a/Project.toml b/Project.toml index eedc493..b4e5434 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -58,7 +57,6 @@ RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10" Tracker = "0.2.34" -UnrolledUtilities = "0.1.2" Zygote = "0.6.69" cuDNN = "1.3" julia = "1.10" diff --git a/src/internal.jl b/src/internal.jl index f2c807e..8277f7c 100644 --- a/src/internal.jl +++ b/src/internal.jl @@ -3,7 +3,6 @@ module Internal using Functors: fmap using Preferences: load_preference using Random: AbstractRNG -using UnrolledUtilities: unrolled_mapreduce using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES, @@ -150,6 +149,34 @@ for op in (:get_device, :get_device_type) end end +function unrolled_mapreduce(f::F, op::O, itr) where {F, O} + return unrolled_mapreduce(f, op, itr, static_length(itr)) +end + +function unrolled_mapreduce(::F, ::O, _, ::Val{0}) where {F, O} + error("Cannot unroll over an empty iterator.") +end + +unrolled_mapreduce(f::F, ::O, itr, ::Val{1}) where {F, O} = f(only(itr)) + +@generated function unrolled_mapreduce(f::F, op::O, itr, ::Val{N}) where {F, O, N} + syms = [gensym("f_itr_$(i)") for i in 1:N] + op_syms = [gensym("op_$(i)") for i in 1:(N - 1)] + f_applied = [:($(syms[i]) = f(itr[$i])) for i in 1:N] + combine_expr = [:($(op_syms[1]) = op($(syms[1]), $(syms[2])))] + for i in 2:(N - 1) + push!(combine_expr, :($(op_syms[i]) = op($(op_syms[i - 1]), $(syms[i + 1])))) + end + return quote + $(Expr(:meta, :inline)) + $(Expr(:inbounds, true)) + $(Expr(:block, f_applied...)) + $(Expr(:inbounds, :pop)) + $(Expr(:block, combine_expr...)) + return $(op_syms[end]) + end +end + function unsafe_free_internal!(x::AbstractArray) unsafe_free_internal!(MLDataDevices.get_device_type(x), x) return @@ -162,4 +189,6 @@ function unsafe_free!(x) return end +static_length(t::Tuple) = Val(length(t)) + end