Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: remove UnrolledUtilities dep
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 11, 2024
1 parent 7361845 commit 2787e22
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand Down Expand Up @@ -58,7 +57,6 @@ RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
Zygote = "0.6.69"
cuDNN = "1.3"
julia = "1.10"
Expand Down
25 changes: 24 additions & 1 deletion src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module Internal
using Functors: fmap
using Preferences: load_preference
using Random: AbstractRNG
using UnrolledUtilities: unrolled_mapreduce

using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES,
Expand Down Expand Up @@ -150,6 +149,28 @@ for op in (:get_device, :get_device_type)
end
end

function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
return unrolled_mapreduce(f, op, itr, static_length(itr))
end

@generated function unrolled_mapreduce(f::F, op::O, itr, ::Val{N}) where {F, O, N}
syms = [gensym("f_itr_$(i)") for i in 1:N]
op_syms = [gensym("op_$(i)") for i in 1:(N-1)]
f_applied = [:($(syms[i]) = f(itr[$i])) for i in 1:N]
combine_expr = [:($(op_syms[1]) = op($(syms[1]), $(syms[2])))]
for i in 2:(N - 1)
push!(combine_expr, :($(op_syms[i]) = op($(op_syms[i - 1]), $(syms[i + 1]))))
end
return quote
$(Expr(:meta, :inline))
$(Expr(:inbounds, true))
$(Expr(:block, f_applied...))
$(Expr(:inbounds, :pop))
$(Expr(:block, combine_expr...))
return $(op_syms[end])
end
end

function unsafe_free_internal!(x::AbstractArray)
unsafe_free_internal!(MLDataDevices.get_device_type(x), x)
return
Expand All @@ -162,4 +183,6 @@ function unsafe_free!(x)
return
end

static_length(t::Tuple) = Val(length(t))

end

0 comments on commit 2787e22

Please sign in to comment.