Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: add unsafe_free!
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 28, 2024
1 parent 1c2619a commit 50df4fa
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 0 deletions.
6 changes: 6 additions & 0 deletions ext/MLDataDevicesAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Intege
return MLDataDevices.set_device!(AMDGPUDevice, id)
end

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray)
AMDGPU.unsafe_free!(x)
return
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
Expand Down
6 changes: 6 additions & 0 deletions ext/MLDataDevicesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ function MLDataDevices.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer)
return MLDataDevices.set_device!(CUDADevice, id)
end

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray)
CUDA.unsafe_free!(x)
return
end

# Device Transfer
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
Expand Down
5 changes: 5 additions & 0 deletions ext/MLDataDevicesMLUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module MLDataDevicesMLUtilsExt

using MLUtils: DataLoader

end
6 changes: 6 additions & 0 deletions ext/MLDataDevicesMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ Internal.get_device(::MtlArray) = MetalDevice()

Internal.get_device_type(::MtlArray) = MetalDevice

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray)
Metal.unsafe_free!(x)
return
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)
Expand Down
6 changes: 6 additions & 0 deletions ext/MLDataDevicesoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ Internal.get_device(::oneArray) = oneAPIDevice()

Internal.get_device_type(::oneArray) = oneAPIDevice

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray)
oneAPI.unsafe_free!(x)
return
end

# Device Transfer
## To GPU
for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
Expand Down
13 changes: 13 additions & 0 deletions src/internal.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Internal

using Functors: fmap
using Preferences: load_preference
using Random: AbstractRNG
using UnrolledUtilities: unrolled_mapreduce
Expand Down Expand Up @@ -149,4 +150,16 @@ for op in (:get_device, :get_device_type)
end
end

function unsafe_free_internal!(x::AbstractArray)
unsafe_free_internal!(MLDataDevices.get_device_type(x), x)
return
end
unsafe_free_internal!(::Type, x::AbstractArray) = nothing
unsafe_free_internal!(_) = nothing

function unsafe_free!(x)
fmap(unsafe_free_internal!, x)
return
end

end

0 comments on commit 50df4fa

Please sign in to comment.