Skip to content

Commit

Permalink
Allow bulk-freeing arrays instead of caching them
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Dec 17, 2024
1 parent dd3af70 commit 40e5447
Showing 1 changed file with 44 additions and 29 deletions.
73 changes: 44 additions & 29 deletions src/host/allocations_cache.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Base.ScopedValues

const CacheAllocatorName = ScopedValue(:none)

struct CacheAllocator{T <: AbstractGPUArray}
lock::ReentrantLock
busy::Dict{UInt64, Vector{T}} # hash((T, dims)) => GPUArray[]
Expand All @@ -21,38 +23,42 @@ function get_pool!(cache::CacheAllocator{T}, pool::Symbol, uid::UInt64) where T
return uid_pool
end

function alloc!(alloc_f, cache::CacheAllocator, ::Type{T}, dims::Dims{N}) where {T, N}
function alloc!(alloc_f, cache::CacheAllocator, ::Type{T}, dims::Dims{N}; skip_free::Bool) where {T, N}
x = nothing
uid = hash((T, dims))
free_pool = get_pool!(cache, :free, uid)
busy_pool = get_pool!(cache, :busy, uid)

x = nothing

# No array available in `free` - call `alloc_f`.
isempty(free_pool) && (x = alloc_f())
if skip_free
x = alloc_f()
else
free_pool = get_pool!(cache, :free, uid)
isempty(free_pool) && (x = alloc_f())

# Otherwise, try fetching from `free`.
while !isempty(free_pool) && x nothing
tmp = pop!(free_pool)
# Array was manually freed via `unsafe_free!`.
storage(tmp).freed && continue
x = tmp
while !isempty(free_pool) && x nothing
tmp = Base.@lock cache.lock pop!(free_pool)
# Array was manually freed via `unsafe_free!`.
storage(tmp).freed && continue
x = tmp
end
end

# No array in cache - call `alloc_f`.
x nothing && (x = alloc_f())
push!(busy_pool, x)
Base.@lock cache.lock push!(busy_pool, x)
return x
end

function free_busy!(cache::CacheAllocator)
function free_busy!(cache::CacheAllocator; free_immediately::Bool)
for uid in cache.busy.keys
busy_pool = get_pool!(cache, :busy, uid)
isempty(busy_pool) && continue

free_pool = get_pool!(cache, :free, uid)
Base.@lock cache.lock begin
append!(free_pool, busy_pool)
if free_immediately
for p in busy_pool unsafe_free!(p) end
else
append!(free_pool, busy_pool)
end
empty!(busy_pool)
end
end
Expand All @@ -61,10 +67,11 @@ end
struct PerDeviceCacheAllocator{T <: AbstractGPUArray}
lock::ReentrantLock
caches::Dict{UInt64, Dict{Symbol, CacheAllocator{T}}}
free_immediately::Bool
end

PerDeviceCacheAllocator(::Type{T}) where T <: AbstractGPUArray =
PerDeviceCacheAllocator(ReentrantLock(), Dict{UInt64, Dict{Symbol, CacheAllocator{T}}}())
PerDeviceCacheAllocator(::Type{T}; free_immediately::Bool) where T <: AbstractGPUArray =
PerDeviceCacheAllocator(ReentrantLock(), Dict{UInt64, Dict{Symbol, CacheAllocator{T}}}(), free_immediately)

function named_cache_allocator!(pdcache::PerDeviceCacheAllocator{T}, device, name::Symbol) where T
h = hash(device)
Expand All @@ -85,6 +92,12 @@ function named_cache_allocator!(pdcache::PerDeviceCacheAllocator{T}, device, nam
return named_cache
end

function alloc!(alloc_f, kab::Backend, name::Symbol, ::Type{T}, dims::Dims{N}) where {T, N}
pdcache = cache_allocator(kab)
cache = named_cache_allocator!(pdcache, device(kab), name)
alloc!(alloc_f, cache, T, dims; skip_free=pdcache.free_immediately)
end

function Base.sizeof(pdcache::PerDeviceCacheAllocator, device, name::Symbol)
sz = UInt64(0)
h = hash(device)
Expand All @@ -106,6 +119,9 @@ function Base.sizeof(pdcache::PerDeviceCacheAllocator, device, name::Symbol)
return sz
end

invalidate_cache_allocator!(kab::Backend, name::Symbol) =
invalidate_cache_allocator!(cache_allocator(kab), device(kab), name)

function invalidate_cache_allocator!(pdcache::PerDeviceCacheAllocator, device, name::Symbol)
h = hash(device)
dev_cache = get(pdcache.caches, h, nothing)
Expand All @@ -128,28 +144,27 @@ function invalidate_cache_allocator!(pdcache::PerDeviceCacheAllocator, device, n
return
end

function free_busy!(kab::Backend, name::Symbol)
pdcache = cache_allocator(kab)
free_busy!(named_cache_allocator!(pdcache, device(kab), name); pdcache.free_immediately)
end

macro cache_scope(backend, name, expr)
quote
scope = cache_alloc_scope($(esc(backend)))
res = @with scope => $(esc(name)) $(esc(expr))
free_busy_cache_alloc!(cache_allocator($(esc(backend))), $(esc(name)))
res = @with $(esc(CacheAllocatorName)) => $(esc(name)) $(esc(expr))
free_busy!($(esc(backend)), $(esc(name)))
res
end
end

macro no_cache_scope(backend, expr)
macro no_cache_scope(expr)
quote
scope = cache_alloc_scope($(esc(backend)))
@with scope => :none $(esc(expr))
@with $(esc(CacheAllocatorName)) => :none $(esc(expr))
end
end

# Interface API.

cache_alloc_scope(::Backend) = error("Not implemented.")

cache_allocator(::Backend) = error("Not implemented.")

free_busy_cache_alloc!(pdcache, name::Symbol) = error("Not implemented.")

invalidate_cache_allocator!(pdcache, name::Symbol) = error("Not implemented.")
device(::Backend) = error("Not implemented.")

0 comments on commit 40e5447

Please sign in to comment.