diff --git a/lib/level-zero/device.jl b/lib/level-zero/device.jl index 4e814f75..1b145812 100644 --- a/lib/level-zero/device.jl +++ b/lib/level-zero/device.jl @@ -1,6 +1,6 @@ export ZeDevice, properties, compute_properties, module_properties, memory_properties, memory_access_properties, cache_properties, image_properties, p2p_properties -struct ZeDevice +struct ZeDevice <: Adapt.AbstractGPUDevice handle::ze_device_handle_t driver::ZeDriver diff --git a/lib/level-zero/oneL0.jl b/lib/level-zero/oneL0.jl index b2f64733..90e2f565 100644 --- a/lib/level-zero/oneL0.jl +++ b/lib/level-zero/oneL0.jl @@ -9,6 +9,9 @@ using Printf using NEO_jll using oneAPI_Level_Zero_Loader_jll +import Adapt + + include("utils.jl") include("pointer.jl") diff --git a/src/array.jl b/src/array.jl index daa1c665..81a4e9c6 100644 --- a/src/array.jl +++ b/src/array.jl @@ -159,6 +159,8 @@ function device(A::oneArray) return oneL0.device(A.storage.buffer) end +Adapt.get_compute_unit_impl(@nospecialize(TypeHistory::Type), A::oneArray) = device(A) + ## derived types diff --git a/src/context.jl b/src/context.jl index 80ab83be..b34d9ad8 100644 --- a/src/context.jl +++ b/src/context.jl @@ -55,3 +55,18 @@ end function oneL0.synchronize() oneL0.synchronize(global_queue(context(), device())) end + + +function Adapt.adapt_storage(dev::ZeDevice, x) + prev_dev = device() + try + device!(dev) + Adapt.adapt_storage(oneArray, x) + finally + device!(prev_dev) + end +end + + +# ToDo: implement Sys.total_memory(dev::ZeDevice) +# ToDo: implement Sys.free_memory(dev::ZeDevice)