Skip to content

Commit

Permalink
Preserve the buffer type when broadcasting. (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Dec 19, 2023
1 parent 3b97437 commit 41b30de
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"

[compat]
Adapt = "2.0, 3.0"
Adapt = "4"
CEnum = "0.4, 0.5"
ExprTools = "0.1"
GPUArrays = "9"
GPUArrays = "10"
GPUCompiler = "0.23, 0.24, 0.25"
KernelAbstractions = "0.9.1"
LLVM = "6"
Expand Down
86 changes: 72 additions & 14 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export oneArray, oneVector, oneMatrix, oneVecOrMat
export oneArray, oneVector, oneMatrix, oneVecOrMat,
is_device, is_shared, is_host


## array type
Expand Down Expand Up @@ -168,6 +169,12 @@ function device(A::oneArray)
return oneL0.device(A.data[])
end

buftype(x::oneArray) = buftype(typeof(x))
buftype(::Type{<:oneArray{<:Any,<:Any,B}}) where {B} = @isdefined(B) ? B : Any

is_device(a::oneArray) = isa(a.data[], oneL0.DeviceBuffer)
is_shared(a::oneArray) = isa(a.data[], oneL0.SharedBuffer)
is_host(a::oneArray) = isa(a.data[], oneL0.HostBuffer)

## derived types

Expand Down Expand Up @@ -195,9 +202,15 @@ const oneStridedVector{T} = oneStridedArray{T,1}
const oneStridedMatrix{T} = oneStridedArray{T,2}
const oneStridedVecOrMat{T} = Union{oneStridedVector{T}, oneStridedMatrix{T}}

Base.pointer(x::oneStridedArray{T}) where {T} = Base.unsafe_convert(ZePtr{T}, x)
@inline function Base.pointer(x::oneStridedArray{T}, i::Integer) where T
Base.unsafe_convert(ZePtr{T}, x) + Base._memory_offset(x, i)
@inline function Base.pointer(x::oneStridedArray{T}, i::Integer=1; type=oneL0.DeviceBuffer) where T
PT = if type == oneL0.DeviceBuffer
ZePtr{T}
elseif type == oneL0.HostBuffer
Ptr{T}
else
error("unknown memory type")
end
Base.unsafe_convert(PT, x) + Base._memory_offset(x, i)
end

# anything that's (secretly) backed by a oneArray
Expand Down Expand Up @@ -241,12 +254,20 @@ oneL0.ZeRef{T}() where {T} = oneL0.ZeRefArray(oneArray{T}(undef, 1))
Base.convert(::Type{T}, x::T) where T <: oneArray = x


## interop with C libraries
## interop with libraries

function Base.unsafe_convert(::Type{Ptr{T}}, x::oneArray{T}) where {T}
buf = x.data[]
if is_device(x)
throw(ArgumentError("cannot take the CPU address of a $(typeof(x))"))
end
convert(Ptr{T}, x.data[]) + x.offset*Base.elsize(x)
end

Base.unsafe_convert(::Type{Ptr{T}}, x::oneArray{T}) where {T} =
throw(ArgumentError("cannot take the host address of a $(typeof(x))"))
Base.unsafe_convert(::Type{ZePtr{T}}, x::oneArray{T}) where {T} =
function Base.unsafe_convert(::Type{ZePtr{T}}, x::oneArray{T}) where {T}
convert(ZePtr{T}, x.data[]) + x.offset*Base.elsize(x)
end



## interop with GPU arrays
Expand All @@ -256,9 +277,6 @@ function Base.unsafe_convert(::Type{oneDeviceArray{T,N,AS.Global}}, a::oneArray{
a.maxsize - a.offset*Base.elsize(a))
end

Adapt.adapt_storage(::KernelAdaptor, xs::oneArray{T,N}) where {T,N} =
Base.unsafe_convert(oneDeviceArray{T,N,AS.Global}, xs)


## memory copying

Expand Down Expand Up @@ -310,7 +328,7 @@ Base.copyto!(dest::oneDenseArray{T}, src::oneDenseArray{T}) where {T} =
copyto!(dest, 1, src, 1, length(src))

function Base.unsafe_copyto!(ctx::ZeContext, dev::ZeDevice,
dest::oneDenseArray{T}, doffs, src::Array{T}, soffs, n) where T
dest::oneDenseArray{T,<:Any,oneL0.DeviceBuffer}, doffs, src::Array{T}, soffs, n) where T
GC.@preserve src dest unsafe_copyto!(ctx, dev, pointer(dest, doffs), pointer(src, soffs), n)
if Base.isbitsunion(T)
# copy selector bytes
Expand All @@ -320,7 +338,7 @@ function Base.unsafe_copyto!(ctx::ZeContext, dev::ZeDevice,
end

function Base.unsafe_copyto!(ctx::ZeContext, dev::ZeDevice,
dest::Array{T}, doffs, src::oneDenseArray{T}, soffs, n) where T
dest::Array{T}, doffs, src::oneDenseArray{T,<:Any,oneL0.DeviceBuffer}, soffs, n) where T
GC.@preserve src dest unsafe_copyto!(ctx, dev, pointer(dest, doffs), pointer(src, soffs), n)
if Base.isbitsunion(T)
# copy selector bytes
Expand All @@ -343,6 +361,46 @@ function Base.unsafe_copyto!(ctx::ZeContext, dev::ZeDevice,
return dest
end

# between Array and host-accessible oneArray

function Base.unsafe_copyto!(ctx::ZeContext, dev,
dest::oneDenseArray{T,<:Any,<:Union{oneL0.SharedBuffer,oneL0.HostBuffer}}, doffs, src::Array{T}, soffs, n) where T
if Base.isbitsunion(T)
# copy selector bytes
error("oneArray does not yet support isbits-union arrays")
end
# XXX: maintain queue-ordered semantics? HostBuffers don't have a device...
GC.@preserve src dest begin
ptr = pointer(dest, doffs)
unsafe_copyto!(pointer(dest, doffs; type=oneL0.HostBuffer), pointer(src, soffs), n)
if Base.isbitsunion(T)
# copy selector bytes
error("oneArray does not yet support isbits-union arrays")
end
end

return dest
end

function Base.unsafe_copyto!(ctx::ZeContext, dev,
dest::Array{T}, doffs, src::oneDenseArray{T,<:Any,<:Union{oneL0.SharedBuffer,oneL0.HostBuffer}}, soffs, n) where T
if Base.isbitsunion(T)
# copy selector bytes
error("oneArray does not yet support isbits-union arrays")
end
# XXX: maintain queue-ordered semantics? HostBuffers don't have a device...
GC.@preserve src dest begin
ptr = pointer(dest, doffs)
unsafe_copyto!(pointer(dest, doffs), pointer(src, soffs; type=oneL0.HostBuffer), n)
if Base.isbitsunion(T)
# copy selector bytes
error("oneArray does not yet support isbits-union arrays")
end
end

return dest
end


## gpu array adaptor

Expand Down Expand Up @@ -375,7 +433,7 @@ end

## derived arrays

function GPUArrays.derive(::Type{T}, N::Int, a::oneArray, dims::Dims, offset::Int) where {T}
function GPUArrays.derive(::Type{T}, a::oneArray, dims::Dims{N}, offset::Int) where {T,N}
offset = (a.offset * Base.elsize(a)) ÷ sizeof(T) + offset
oneArray{T,N}(a.data, dims; a.maxsize, offset)
end
Expand Down
26 changes: 14 additions & 12 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@

using Base.Broadcast: BroadcastStyle, Broadcasted

struct oneArrayStyle{N} <: AbstractGPUArrayStyle{N} end
oneArrayStyle(::Val{N}) where N = oneArrayStyle{N}()
oneArrayStyle{M}(::Val{N}) where {N,M} = oneArrayStyle{N}()
struct oneArrayStyle{N,B} <: AbstractGPUArrayStyle{N} end
oneArrayStyle{M,B}(::Val{N}) where {N,M,B} = oneArrayStyle{N,B}()

BroadcastStyle(::Type{<:oneArray{T,N}}) where {T,N} = oneArrayStyle{N}()
# identify the broadcast style of a (wrapped) oneArray
BroadcastStyle(::Type{<:oneArray{T,N,B}}) where {T,N,B} = oneArrayStyle{N,B}()
BroadcastStyle(W::Type{<:oneWrappedArray{T,N}}) where {T,N} =
oneArrayStyle{N, buftype(Adapt.unwrap_type(W))}()

Base.similar(bc::Broadcasted{oneArrayStyle{N}}, ::Type{T}) where {N,T} =
similar(oneArray{T}, axes(bc))
# when we are dealing with different buffer styles, we cannot know
# which one is better, so use shared memory
BroadcastStyle(::oneArrayStyle{N, B1},
::oneArrayStyle{N, B2}) where {N,B1,B2} =
oneArrayStyle{N, oneL0.SharedBuffer}()

Base.similar(bc::Broadcasted{oneArrayStyle{N}}, ::Type{T}, dims...) where {N,T} =
oneArray{T}(undef, dims...)

# broadcasting type ctors isn't GPU compatible
Broadcast.broadcasted(::oneArrayStyle{N}, f::Type{T}, args...) where {N, T} =
Broadcasted{oneArrayStyle{N}}((x...) -> T(x...), args, nothing)
# allocation of output arrays
Base.similar(bc::Broadcasted{oneArrayStyle{N,B}}, ::Type{T}, dims) where {T,N,B} =
similar(oneArray{T,length(dims),B}, dims)
14 changes: 12 additions & 2 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,15 @@ end

struct KernelAdaptor end

# convert oneL0 host pointers to device pointers
# convert oneAPI host pointers to device pointers
Adapt.adapt_storage(to::KernelAdaptor, p::ZePtr{T}) where {T} = reinterpret(Ptr{T}, p)

# Base.RefValue isn't GPU compatible, so provide a compatible alternative
# convert oneAPI host arrays to device arrays
Adapt.adapt_storage(::KernelAdaptor, xs::oneArray{T,N}) where {T,N} =
Base.unsafe_convert(oneDeviceArray{T,N,AS.Global}, xs)

# Base.RefValue isn't GPU compatible, so provide a compatible alternative.
# TODO: port improvements from CUDA.jl
struct ZeRefValue{T} <: Ref{T}
x::T
end
Expand All @@ -100,6 +105,11 @@ Base.getindex(r::oneRefType{T}) where T = T
Adapt.adapt_structure(to::KernelAdaptor, r::Base.RefValue{<:Union{DataType,Type}}) =
oneRefType{r[]}()

# case where type is the function being broadcasted
Adapt.adapt_structure(to::KernelAdaptor,
bc::Broadcast.Broadcasted{Style, <:Any, Type{T}}) where {Style, T} =
Broadcast.Broadcasted{Style}((x...) -> T(x...), adapt(to, bc.args), bc.axes)

"""
kernel_convert(x)
Expand Down
10 changes: 5 additions & 5 deletions src/oneAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ include("device/quirks.jl")
# essential stuff
include("context.jl")

# compiler implementation
include("compiler/compilation.jl")
include("compiler/execution.jl")
include("compiler/reflection.jl")

# array abstraction
include("memory.jl")
include("pool.jl")
include("array.jl")

# compiler implementation
include("compiler/compilation.jl")
include("compiler/execution.jl")
include("compiler/reflection.jl")

# array libraries
include("../lib/mkl/oneMKL.jl")
export oneMKL
Expand Down
13 changes: 10 additions & 3 deletions src/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ function allocate(::Type{oneL0.SharedBuffer}, ctx, dev, bytes::Int, alignment::I
return buf
end

function allocate(::Type{oneL0.HostBuffer}, ctx, dev, bytes::Int, alignment::Int)
bytes == 0 && return oneL0.HostBuffer(ZE_NULL, bytes, ctx)
host_alloc(ctx, bytes, alignment)
end

function release(buf::oneL0.AbstractBuffer)
sizeof(buf) == 0 && return

ctx = oneL0.context(buf)
dev = oneL0.device(buf)
if buf isa oneL0.DeviceBuffer || buf isa oneL0.SharedBuffer
ctx = oneL0.context(buf)
dev = oneL0.device(buf)
evict(ctx, dev, buf)
end

evict(ctx, dev, buf)
free(buf; policy=oneL0.ZE_DRIVER_MEMORY_FREE_POLICY_EXT_FLAG_BLOCKING_FREE)

# TODO: queue-ordered free from non-finalizer tasks once we have
Expand Down
16 changes: 16 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,19 @@ end
oneAPI.@sync copyto!(a, 2, [200], 1, 1)
@test b == [100, 200]
end

# https://github.com/JuliaGPU/CUDA.jl/issues/2191
@testset "preserving buffer types" begin
a = oneVector{Int,oneL0.SharedBuffer}([1])
@test oneAPI.buftype(a) == oneL0.SharedBuffer

# unified-ness should be preserved
b = a .+ 1
@test oneAPI.buftype(b) == oneL0.SharedBuffer

# when there's a conflict, we should defer to unified memory
c = oneVector{Int,oneL0.HostBuffer}([1])
d = oneVector{Int,oneL0.DeviceBuffer}([1])
e = c .+ d
@test oneAPI.buftype(e) == oneL0.SharedBuffer
end

0 comments on commit 41b30de

Please sign in to comment.