From 76d6d06aef79b6995cbbb77a04aa70a2ab651a98 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 1 Nov 2023 18:50:52 +0100 Subject: [PATCH] Improve support for unified and host memory (#2138) - scalar indexing is now allowed with unified and host arrays - construction using `cu` has been generalized, now taking `device`, `host`, and `unified` boolean kwargs - the default memory location can be configured using the `default_memory` preference (now reported by `versioninfo`) - `unsafe_wrap` has been extended to take array inputs - HMM support, allowing conversion of unmanaged CPU memory to `CuArray` objects - CI coverage for all of the above --- .buildkite/pipeline.yml | 57 ++++++--- LocalPreferences.toml | 4 + lib/cusparse/array.jl | 14 +- src/array.jl | 262 +++++++++++++++++++++++++++++++------- src/compiler/execution.jl | 39 +++--- src/pool.jl | 6 + src/texture.jl | 2 +- src/utilities.jl | 17 +++ test/base/array.jl | 167 +++++++++++++----------- test/core/execution.jl | 2 +- test/runtests.jl | 4 + 11 files changed, 408 insertions(+), 166 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index edd315982c..a23e01d9d0 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -19,8 +19,7 @@ steps: cuda: "*" commands: | echo -e "[CUDA_Runtime_jll]\nlocal = \"true\"" >LocalPreferences.toml - if: build.message !~ /\[skip tests\]/ && - build.message !~ /\[skip julia\]/ + if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 120 matrix: setup: @@ -44,7 +43,7 @@ steps: - JuliaCI/julia#v1: version: 1.9 - JuliaCI/julia-test#v1: - test_args: "core base libraries" + test_args: "--quickfail core base libraries" - JuliaCI/julia-coverage#v1: dirs: - src @@ -53,9 +52,7 @@ agents: queue: "juliagpu" cuda: "*" - if: build.message !~ /\[skip tests\]/ && - build.message !~ /\[skip cuda\]/ && - !build.pull_request.draft + if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft timeout_in_minutes: 120 matrix: setup: @@ -73,6 +70,34 @@ echo -e "[CUDA_Runtime_jll]\nversion = \"{{matrix.cuda}}\"" >LocalPreferences.toml echo -e "[CUDA_Driver_jll]\ncompat = \"false\"" >>LocalPreferences.toml + - group: "Memory" + key: "memory" + depends_on: "julia" + steps: + - label: "CuArray with {{matrix.memory}} memory" + plugins: + - JuliaCI/julia#v1: + version: 1.9 + - JuliaCI/julia-test#v1: + test_args: "--quickfail core base libraries" + - JuliaCI/julia-coverage#v1: + dirs: + - src + - lib + - examples + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft + timeout_in_minutes: 120 + matrix: + setup: + memory: + - "unified" + - "host" + commands: | + echo -e "[CUDA]\ndefault_memory = \"{{matrix.memory}}\"" >LocalPreferences.toml + - group: ":nesting_dolls: Subpackages" depends_on: "cuda" steps: @@ -104,9 +129,7 @@ agents: queue: "juliagpu" cuda: "*" - if: build.message !~ /\[skip tests\]/ && - build.message !~ /\[skip subpackages\]/ && - !build.pull_request.draft + if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft timeout_in_minutes: 120 commands: | julia --project -e ' @@ -165,9 +188,7 @@ agents: queue: "juliagpu" cuda: "*" - if: build.message !~ /\[skip tests\]/ && - build.message !~ /\[skip downstream\]/ && - !build.pull_request.draft + if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft timeout_in_minutes: 60 soft_fail: - exit_status: 3 @@ -240,9 +261,7 @@ cuda: "*" env: JULIA_CUDA_USE_COMPAT: 'false' # NVIDIA bug #3418723: injection tools prevent
probing libcuda - if: build.message !~ /\[skip tests\]/ && - build.message !~ /\[skip sanitizer\]/ && - !build.pull_request.draft + if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft timeout_in_minutes: 10 # we want to benchmark every commit on the master branch, even if it failed CI @@ -274,9 +293,8 @@ steps: agents: queue: "juliagpu" cuda: "*" - if: build.message !~ /\[skip benchmarks\]/ && - build.branch !~ /^master$$/ && - !build.pull_request.draft + if: build.message !~ /\[skip benchmarks\]/ && !build.pull_request.draft && + build.branch !~ /^master$$/ timeout_in_minutes: 30 # if we will submit results, use the benchmark queue so that we will @@ -310,8 +328,7 @@ steps: queue: "benchmark" gpu: "rtx2070" cuda: "*" - if: build.message !~ /\[skip benchmarks\]/ && - build.branch =~ /^master$$/ + if: build.message !~ /\[skip benchmarks\]/ && build.branch =~ /^master$$/ matrix: setup: julia: diff --git a/LocalPreferences.toml b/LocalPreferences.toml index e0e7507033..513fc75593 100644 --- a/LocalPreferences.toml +++ b/LocalPreferences.toml @@ -12,6 +12,10 @@ # making it possible to do use cooperative multitasking. #nonblocking_synchronization = true +# which memory type unspecified allocations should default to. +# possible values: "device", "unified", "host" +#default_memory = "device" + [CUDA_Driver_jll] # whether to attempt to load a forwards-compatibile userspace driver. # only turn this off if you experience issues, e.g., when using a local diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl index 0a5f2c0bbf..64f79eed82 100644 --- a/lib/cusparse/array.jl +++ b/lib/cusparse/array.jl @@ -417,9 +417,9 @@ Adapt.adapt_storage(::Type{CuArray}, xs::SparseMatrixCSC) = CuSparseMatrixCSC(xs Adapt.adapt_storage(::Type{CuArray{T}}, xs::SparseVector) where {T} = CuSparseVector{T}(xs) Adapt.adapt_storage(::Type{CuArray{T}}, xs::SparseMatrixCSC) where {T} = CuSparseMatrixCSC{T}(xs) -Adapt.adapt_storage(::CUDA.CuArrayAdaptor, xs::AbstractSparseArray) = +Adapt.adapt_storage(::CUDA.CuArrayKernelAdaptor, xs::AbstractSparseArray) = adapt(CuArray, xs) -Adapt.adapt_storage(::CUDA.CuArrayAdaptor, xs::AbstractSparseArray{<:AbstractFloat}) = +Adapt.adapt_storage(::CUDA.CuArrayKernelAdaptor, xs::AbstractSparseArray{<:AbstractFloat}) = adapt(CuArray{Float32}, xs) Adapt.adapt_storage(::Type{Array}, xs::CuSparseVector) = SparseVector(xs) @@ -546,7 +546,7 @@ end # interop with device arrays -function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseVector) +function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseVector) return CuSparseDeviceVector( adapt(to, x.iPtr), adapt(to, x.nzVal), @@ -554,7 +554,7 @@ function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseVector) ) end -function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixCSR) +function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSR) return CuSparseDeviceMatrixCSR( adapt(to, x.rowPtr), adapt(to, x.colVal), @@ -563,7 +563,7 @@ function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixCSR) ) end -function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixCSC) +function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSC) return CuSparseDeviceMatrixCSC( adapt(to, x.colPtr), adapt(to, x.rowVal), @@ -572,7 +572,7 @@ function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixCSC) ) end -function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixBSR) +function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixBSR) return 
CuSparseDeviceMatrixBSR( adapt(to, x.rowPtr), adapt(to, x.colVal), @@ -582,7 +582,7 @@ end -function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixCOO) +function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCOO) return CuSparseDeviceMatrixCOO( adapt(to, x.rowInd), adapt(to, x.colInd), diff --git a/src/array.jl b/src/array.jl index 3708c5ca3a..dffd661e5f 100644 --- a/src/array.jl +++ b/src/array.jl @@ -1,4 +1,4 @@ -export CuArray, CuVector, CuMatrix, CuVecOrMat, cu, is_unified +export CuArray, CuVector, CuMatrix, CuVecOrMat, cu, is_device, is_unified, is_host ## array type @@ -132,10 +132,23 @@ const CuVector{T} = CuArray{T,1} const CuMatrix{T} = CuArray{T,2} const CuVecOrMat{T} = Union{CuVector{T},CuMatrix{T}} -# default to non-unified memory +# unspecified memory allocation +const default_memory = let str = Preferences.@load_preference("default_memory", "device") + if str == "device" + Mem.DeviceBuffer + elseif str == "unified" + Mem.UnifiedBuffer + elseif str == "host" + Mem.HostBuffer + else + error("unknown default memory type: $str") + end +end CuArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} = - CuArray{T,N,Mem.DeviceBuffer}(undef, dims) + CuArray{T,N,default_memory}(undef, dims) +is_device(a::CuArray) = isa(a.data[], Mem.DeviceBuffer) is_unified(a::CuArray) = isa(a.data[], Mem.UnifiedBuffer) +is_host(a::CuArray) = isa(a.data[], Mem.HostBuffer) # buffer, type and dimensionality specified CuArray{T,N,B}(::UndefInitializer, dims::NTuple{N,Integer}) where {T,N,B} = @@ -194,36 +207,47 @@ function Base.deepcopy_internal(x::CuArray, dict::IdDict) end +## unsafe_wrap + """ unsafe_wrap(CuArray, ptr::CuPtr{T}, dims; own=false, ctx=context()) -Wrap a `CuArray` object around the data at the address given by `ptr`. The pointer -element type `T` determines the array element type. `dims` is either an integer (for a 1d -array) or a tuple of the array dimensions. `own` optionally specified whether Julia should -take ownership of the memory, calling `cudaFree` when the array is no longer referenced. The -`ctx` argument determines the CUDA context where the data is allocated in. + # requires unified or host memory + unsafe_wrap(Array, a::CuArray) + + # requires HMM + unsafe_wrap(CuArray, ptr::Ptr{T}, dims) + unsafe_wrap(CuArray, a::Array) + +Wrap a `CuArray` object around the data at the address given by the CUDA-managed pointer +`ptr`. The element type `T` determines the array element type. `dims` is either an integer +(for a 1d array) or a tuple of the array dimensions. `own` optionally specifies whether +Julia should take ownership of the memory, calling `cudaFree` when the array is no longer +referenced. The `ctx` argument determines the CUDA context in which the data is allocated. """ +unsafe_wrap + +# managed pointer to CuArray function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,N}}}, ptr::CuPtr{T}, dims::NTuple{N,Int}; own::Bool=false, ctx::CuContext=context()) where {T,N} - buf = _unsafe_wrap(T, ptr, dims; own, ctx) + buf = _unsafe_wrap_managed(T, ptr, dims; own, ctx) data = DataRef(own ? _free_buffer : (args...)
-> (#= do nothing =#), buf) - CuArray{T, length(dims)}(data, dims) + CuArray{T,N}(data, dims) end function Base.unsafe_wrap(::Type{CuArray{T,N,B}}, ptr::CuPtr{T}, dims::NTuple{N,Int}; own::Bool=false, ctx::CuContext=context()) where {T,N,B} - buf = _unsafe_wrap(T, ptr, dims; own, ctx) + buf = _unsafe_wrap_managed(T, ptr, dims; own, ctx) if typeof(buf) !== B - error("Declared buffer type does not match inferred buffer type.") + throw(ArgumentError("Declared buffer type does not match inferred buffer type.")) end data = DataRef(own ? _free_buffer : (args...) -> (#= do nothing =#), buf) - CuArray{T, length(dims)}(data, dims) + CuArray{T,N}(data, dims) end - -function _unsafe_wrap(::Type{T}, ptr::CuPtr{T}, dims::NTuple{N,Int}; +function _unsafe_wrap_managed(::Type{T}, ptr::CuPtr{T}, dims::NTuple{N,Int}; own::Bool=false, ctx::CuContext=context()) where {T,N} - isbitstype(T) || error("Can only unsafe_wrap a pointer to a bits type") + isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type")) sz = prod(dims) * sizeof(T) # identify the buffer @@ -240,24 +264,84 @@ function _unsafe_wrap(::Type{T}, ptr::CuPtr{T}, dims::NTuple{N,Int}; error("Unknown memory type; please file an issue.") end catch err - error("Could not identify the buffer type; are you passing a valid CUDA pointer to unsafe_wrap?") + throw(ArgumentError("Could not identify the buffer type; are you passing a valid CUDA pointer to unsafe_wrap?")) end return buf end - -function Base.unsafe_wrap(Atype::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,1}}}, +# integer size input +function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,1}}}, p::CuPtr{T}, dim::Int; own::Bool=false, ctx::CuContext=context()) where {T} - unsafe_wrap(Atype, p, (dim,); own, ctx) + unsafe_wrap(CuArray{T,1}, p, (dim,); own, ctx) end -function Base.unsafe_wrap(Atype::Type{CuArray{T,1,B}}, - p::CuPtr{T}, dim::Int; +function Base.unsafe_wrap(::Type{CuArray{T,1,B}}, p::CuPtr{T}, dim::Int; own::Bool=false, ctx::CuContext=context()) where {T,B} - unsafe_wrap(Atype, p, (dim,); own, ctx) + unsafe_wrap(CuArray{T,1,B}, p, (dim,); own, ctx) end -Base.unsafe_wrap(T::Type{<:CuArray}, ::Ptr, dims::NTuple{N,Int}; kwargs...) 
where {N} = - throw(ArgumentError("cannot wrap a CPU pointer with a $T")) +# managed pointer to Array +function Base.unsafe_wrap(::Union{Type{Array},Type{Array{T}},Type{Array{T,N}}}, + p::CuPtr{T}, dims::NTuple{N,Int}; + own::Bool=false) where {T,N} + if !is_managed(p) && memory_type(p) != CU_MEMORYTYPE_HOST + throw(ArgumentError("Can only create a CPU array object from a unified or host CUDA array")) + end + unsafe_wrap(Array{T,N}, reinterpret(Ptr{T}, p), dims; own) +end +# integer size input +function Base.unsafe_wrap(::Union{Type{Array},Type{Array{T}},Type{Array{T,1}}}, + p::CuPtr{T}, dim::Int; own::Bool=false) where {T} + unsafe_wrap(Array{T,1}, p, (dim,); own) +end +# array input +function Base.unsafe_wrap(::Union{Type{Array},Type{Array{T}},Type{Array{T,N}}}, + a::CuArray{T,N}) where {T,N} + p = pointer(a; type=Mem.Host) + unsafe_wrap(Array, p, size(a)) +end + +# unmanaged pointer to CuArray +function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,N}}}, + p::Ptr{T}, dims::NTuple{N,Int}; ctx::CuContext=context()) where {T,N} + isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type")) + sz = prod(dims) * sizeof(T) + + if driver_version() < v"12.2" + error("Accessing host memory requires HMM support, which is only available in CUDA 12.2+ using the open-source driver.") + end + if attribute(device(), DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS) != 1 + error("Accessing host memory requires HMM support, which is not provided by your $(name(device())).") + end + + buf = Mem.UnifiedBuffer(ctx, reinterpret(CuPtr{Nothing}, p), sz) + data = DataRef((args...) -> (#= do nothing =#), buf) + CuArray{T,N}(data, dims) +end +function Base.unsafe_wrap(::Type{CuArray{T,N,B}}, p::Ptr{T}, dims::NTuple{N,Int}; + ctx::CuContext=context()) where {T,N,B} + if B !== Mem.UnifiedBuffer + throw(ArgumentError("Can only wrap an unmanaged pointer to a CuArray with a UnifiedBuffer")) + end + unsafe_wrap(CuArray{T,N}, p, dims; ctx) +end +# integer size input +function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,1}}}, + p::Ptr{T}, dim::Int) where {T} + unsafe_wrap(CuArray{T,1}, p, (dim,)) +end +function Base.unsafe_wrap(::Type{CuArray{T,1,B}}, p::Ptr{T}, dim::Int) where {T,B} + unsafe_wrap(CuArray{T,1,B}, p, (dim,)) +end +# array input +function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,N}}}, + a::Array{T,N}) where {T,N} + p = pointer(a) + unsafe_wrap(CuArray{T,N}, p, size(a)) +end +function Base.unsafe_wrap(::Type{CuArray{T,1,B}}, a::Array{T,1}) where {T,B} + p = pointer(a) + unsafe_wrap(CuArray{T,1,B}, p, size(a)) +end ## array interface @@ -299,9 +383,15 @@ const StridedCuVector{T} = StridedCuArray{T,1} const StridedCuMatrix{T} = StridedCuArray{T,2} const StridedCuVecOrMat{T} = Union{StridedCuVector{T}, StridedCuMatrix{T}} -Base.pointer(x::StridedCuArray{T}) where {T} = Base.unsafe_convert(CuPtr{T}, x) -@inline function Base.pointer(x::StridedCuArray{T}, i::Integer) where T - Base.unsafe_convert(CuPtr{T}, x) + Base._memory_offset(x, i) +@inline function Base.pointer(x::StridedCuArray{T}, i::Integer=1; type=Mem.Device) where T + PT = if type == Mem.Device + CuPtr{T} + elseif type == Mem.Host + Ptr{T} + else + error("unknown memory type") + end + Base.unsafe_convert(PT, x) + Base._memory_offset(x, i) end # anything that's (secretly) backed by a CuArray @@ -320,7 +410,7 @@ const AnyCuVecOrMat{T} = Union{AnyCuVector{T}, AnyCuMatrix{T}} end @inline CuArray{T,N}(xs::AbstractArray{<:Any,N}) where {T,N} = - 
CuArray{T,N,Mem.Device}(xs) + CuArray{T,N,default_memory}(xs) @inline CuArray{T,N}(xs::CuArray{<:Any,N,B}) where {T,N,B} = CuArray{T,N,B}(xs) @@ -340,12 +430,64 @@ CuArray{T,N}(xs::CuArray{T,N,B}) where {T,N,B} = copy(xs) Base.convert(::Type{T}, x::T) where T <: CuArray = x -## interop with C libraries +## interop with libraries + +# when CPU-accessible buffers are converted to a device pointer, we assume it will be +# accessed asynchronously. we keep track of that in the task local storage, and use that +# information to perform additional synchronization when converting the buffer to a host +# pointer. TODO: optimize this! it currently halves the performance of scalar indexing. +function mark_async(buf::Union{Mem.HostBuffer,Mem.UnifiedBuffer}) + ptr = convert(Ptr{Nothing}, buf) + tls = task_local_storage() + if haskey(tls, :CUDA_ASYNC_BUFFERS) + async_buffers = tls[:CUDA_ASYNC_BUFFERS]::Vector{Ptr{Nothing}} + in(ptr, async_buffers) && return + pushfirst!(async_buffers, ptr) + else + tls[:CUDA_ASYNC_BUFFERS] = [ptr] + end + return +end +function ensure_sync(buf::Union{Mem.HostBuffer,Mem.UnifiedBuffer}) + tls = task_local_storage() + haskey(tls, :CUDA_ASYNC_BUFFERS) || return + async_buffers = tls[:CUDA_ASYNC_BUFFERS]::Vector{Ptr{Nothing}} + ptr = convert(Ptr{Nothing}, buf) + in(ptr, async_buffers) || return + synchronize() + filter!(!isequal(ptr), async_buffers) + return +end -Base.unsafe_convert(::Type{Ptr{T}}, x::CuArray{T}) where {T} = - throw(ArgumentError("cannot take the CPU address of a $(typeof(x))")) function Base.unsafe_convert(::Type{CuPtr{T}}, x::CuArray{T}) where {T} - convert(CuPtr{T}, x.data[]) + x.offset*Base.elsize(x) + buf = x.data[] + if is_unified(x) || is_host(x) + mark_async(buf) + end + convert(CuPtr{T}, buf) + x.offset*Base.elsize(x) +end + +function Base.unsafe_convert(::Type{Ptr{T}}, x::CuArray{T}) where {T} + buf = x.data[] + if is_device(x) + throw(ArgumentError("cannot take the CPU address of a $(typeof(x))")) + elseif is_unified(x) || is_host(x) + ensure_sync(buf) + end + convert(Ptr{T}, buf) + x.offset*Base.elsize(x) +end + + +## indexing + +function Base.getindex(x::CuArray{<:Any, <:Any, <:Union{Mem.Host,Mem.Unified}}, I::Int) + @boundscheck checkbounds(x, I) + unsafe_load(pointer(x, I; type=Mem.Host)) +end + +function Base.setindex!(x::CuArray{<:Any, <:Any, <:Union{Mem.Host,Mem.Unified}}, v, I::Int) + @boundscheck checkbounds(x, I) + unsafe_store!(pointer(x, I; type=Mem.Host), v) end @@ -360,8 +502,16 @@ end ## memory copying typetagdata(a::Array, i=1) = ccall(:jl_array_typetagdata, Ptr{UInt8}, (Any,), a) + i - 1 -typetagdata(a::CuArray, i=1) = - convert(CuPtr{UInt8}, a.data[]) + a.maxsize + a.offset + i - 1 +function typetagdata(a::CuArray, i=1; type=Mem.Device) + PT = if type == Mem.Device + CuPtr{UInt8} + elseif type == Mem.Host + Ptr{UInt8} + else + error("unknown memory type") + end + convert(PT, a.data[]) + a.maxsize + a.offset + i - 1 +end function Base.copyto!(dest::DenseCuArray{T}, doffs::Integer, src::Array{T}, soffs::Integer, n::Integer) where T @@ -476,11 +626,11 @@ function Base.unsafe_copyto!(dest::DenseCuArray{T,<:Any,<:Union{Mem.UnifiedBuffe synchronize() GC.@preserve src dest begin - cpu_ptr = pointer(src, soffs) - unsafe_copyto!(host_pointer(pointer(dest, doffs)), cpu_ptr, n) + ptr = pointer(src, soffs) + unsafe_copyto!(pointer(dest, doffs; type=Mem.Host), ptr, n) if Base.isbitsunion(T) - cpu_ptr = typetagdata(src, soffs) - unsafe_copyto!(host_pointer(typetagdata(dest, doffs)), cpu_ptr, n) + ptr = typetagdata(src, soffs) + 
unsafe_copyto!(typetagdata(dest, doffs; type=Mem.Host), ptr, n) end end return dest @@ -492,11 +642,11 @@ function Base.unsafe_copyto!(dest::Array{T}, doffs, synchronize() GC.@preserve src dest begin - cpu_ptr = pointer(dest, doffs) - unsafe_copyto!(cpu_ptr, host_pointer(pointer(src, soffs)), n) + ptr = pointer(dest, doffs) + unsafe_copyto!(ptr, pointer(src, soffs; type=Mem.Host), n) if Base.isbitsunion(T) - cpu_ptr = typetagdata(dest, doffs) - unsafe_copyto!(cpu_ptr, host_pointer(typetagdata(src, soffs)), n) + ptr = typetagdata(dest, doffs) + unsafe_copyto!(ptr, typetagdata(src, soffs; type=Mem.Host), n) end end @@ -564,19 +714,19 @@ Adapt.adapt_storage(::Type{<:CuArray{T, N, B}}, xs::AT) where {T, N, B, AT<:Abst # eagerly converts Float64 to Float32, for performance reasons -struct CuArrayAdaptor{B} end +struct CuArrayKernelAdaptor{B} end -Adapt.adapt_storage(::CuArrayAdaptor{B}, xs::AbstractArray{T,N}) where {T,N,B} = +Adapt.adapt_storage(::CuArrayKernelAdaptor{B}, xs::AbstractArray{T,N}) where {T,N,B} = isbits(xs) ? xs : CuArray{T,N,B}(xs) -Adapt.adapt_storage(::CuArrayAdaptor{B}, xs::AbstractArray{T,N}) where {T<:AbstractFloat,N,B} = +Adapt.adapt_storage(::CuArrayKernelAdaptor{B}, xs::AbstractArray{T,N}) where {T<:AbstractFloat,N,B} = isbits(xs) ? xs : CuArray{Float32,N,B}(xs) -Adapt.adapt_storage(::CuArrayAdaptor{B}, xs::AbstractArray{T,N}) where {T<:Complex{<:AbstractFloat},N,B} = +Adapt.adapt_storage(::CuArrayKernelAdaptor{B}, xs::AbstractArray{T,N}) where {T<:Complex{<:AbstractFloat},N,B} = isbits(xs) ? xs : CuArray{ComplexF32,N,B}(xs) # not for Float16 -Adapt.adapt_storage(::CuArrayAdaptor{B}, xs::AbstractArray{T,N}) where {T<:Union{Float16,BFloat16},N,B} = +Adapt.adapt_storage(::CuArrayKernelAdaptor{B}, xs::AbstractArray{T,N}) where {T<:Union{Float16,BFloat16},N,B} = isbits(xs) ? xs : CuArray{T,N,B}(xs) """ @@ -621,7 +771,21 @@ julia> CuArray(1:3) 3 ``` """ -@inline cu(xs; unified::Bool=false) = adapt(CuArrayAdaptor{unified ? Mem.UnifiedBuffer : Mem.DeviceBuffer}(), xs) +@inline function cu(xs; device::Bool=false, unified::Bool=false, host::Bool=false) + if device + unified + host > 1 + throw(ArgumentError("Can only specify one of `device`, `unified`, or `host`")) + end + memory = if device + Mem.DeviceBuffer + elseif unified + Mem.UnifiedBuffer + elseif host + Mem.HostBuffer + else + default_memory + end + adapt(CuArrayKernelAdaptor{memory}(), xs) +end Base.getindex(::typeof(cu), xs...) = CuArray([xs...]) diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 9276876026..177da0efa2 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -121,35 +121,44 @@ end ## host to device value conversion -struct Adaptor end +struct KernelAdaptor end # convert CUDA host pointers to device pointers # TODO: use ordinary ptr? -Adapt.adapt_storage(to::Adaptor, p::CuPtr{T}) where {T} = reinterpret(LLVMPtr{T,AS.Generic}, p) +Adapt.adapt_storage(to::KernelAdaptor, p::CuPtr{T}) where {T} = + reinterpret(LLVMPtr{T,AS.Generic}, p) + +# convert CUDA host arrays to device arrays +function Adapt.adapt_storage(::KernelAdaptor, xs::DenseCuArray{T,N}) where {T,N} + # prefetch unified memory as we're likely to use it on the GPU + # TODO: make this configurable? 
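+ # (rationale for the guards below: an empty array has nothing to prefetch, and
+ # prefetching is not supported while the stream is being captured)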
+ if is_unified(xs) && sizeof(xs) > 0 && !is_capturing() + buf = xs.data[] + subbuf = Mem.UnifiedBuffer(buf.ctx, pointer(xs), sizeof(xs)) + Mem.prefetch(subbuf) + end + + Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs) +end # Base.RefValue isn't GPU compatible, so provide a compatible alternative struct CuRefValue{T} <: Ref{T} x::T end Base.getindex(r::CuRefValue{T}) where T = r.x -Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = CuRefValue(adapt(to, r[])) +Adapt.adapt_structure(to::KernelAdaptor, r::Base.RefValue) = CuRefValue(adapt(to, r[])) # broadcast sometimes passes a ref(type), resulting in a GPU-incompatible DataType box. # avoid that by using a special kind of ref that knows about the boxed type. struct CuRefType{T} <: Ref{DataType} end Base.getindex(r::CuRefType{T}) where T = T -Adapt.adapt_structure(to::Adaptor, r::Base.RefValue{<:Union{DataType,Type}}) = CuRefType{r[]}() +Adapt.adapt_structure(to::KernelAdaptor, r::Base.RefValue{<:Union{DataType,Type}}) = + CuRefType{r[]}() # case where type is the function being broadcasted -Adapt.adapt_structure(to::Adaptor, bc::Base.Broadcast.Broadcasted{Style, <:Any, Type{T}}) where {Style, T} = - Base.Broadcast.Broadcasted{Style}((x...) -> T(x...), adapt(to, bc.args), bc.axes) - -Adapt.adapt_storage(::Adaptor, xs::CuArray{T,N}) where {T,N} = - Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs) - -# we materialize ReshapedArray/ReinterpretArray/SubArray/... directly as a device array -Adapt.adapt_structure(::Adaptor, xs::DenseCuArray{T,N}) where {T,N} = - Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs) +Adapt.adapt_structure(to::KernelAdaptor, + bc::Broadcast.Broadcasted{Style, <:Any, Type{T}}) where {Style, T} = + Broadcast.Broadcasted{Style}((x...) -> T(x...), adapt(to, bc.args), bc.axes) """ cudaconvert(x) @@ -159,9 +168,16 @@ converted to a GPU-friendly format. By default, the function does nothing and re input object `x` as-is. Do not add methods to this function, but instead extend the underlying Adapt.jl package and -register methods for the the `CUDA.Adaptor` type. +register methods for the `CUDA.KernelAdaptor` type.
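+
+For example, to convert the storage of a custom host-side wrapper when it is passed to a
+kernel (a sketch; `MyWrapper` is a hypothetical user type with a `data` field):
+
+    using Adapt
+    Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::MyWrapper) =
+        MyWrapper(adapt(to, x.data))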
""" -cudaconvert(arg) = adapt(Adaptor(), arg) +cudaconvert(arg) = adapt(KernelAdaptor(), arg) ## abstract kernel functionality diff --git a/src/pool.jl b/src/pool.jl index 11b9aee724..9d476d591f 100644 --- a/src/pool.jl +++ b/src/pool.jl @@ -451,6 +451,12 @@ end end buf, time end +@inline function _alloc(::Type{Mem.HostBuffer}, sz; stream::Union{Nothing,CuStream}) + time = Base.@elapsed begin + buf = Mem.alloc(Mem.Host, sz) + end + buf, time +end """ free(buf) diff --git a/src/texture.jl b/src/texture.jl index 24f9da6fc9..4f1a9444eb 100644 --- a/src/texture.jl +++ b/src/texture.jl @@ -319,6 +319,6 @@ memory_source(::Any) = error("Unknown texture source $(typeof(t))") memory_source(::CuArray) = LinearMemory() memory_source(::CuTextureArray) = ArrayMemory() -Adapt.adapt_storage(::Adaptor, t::CuTexture{T,N}) where {T,N} = +Adapt.adapt_storage(::KernelAdaptor, t::CuTexture{T,N}) where {T,N} = CuDeviceTexture{T,N,typeof(memory_source(parent(t))), t.normalized_coordinates, typeof(t.interpolation)}(size(t), t.handle) diff --git a/src/utilities.jl b/src/utilities.jl index 9a447fe49f..29972e7741 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -110,6 +110,23 @@ function versioninfo(io::IO=stdout) println(io) end + prefs = [ + "nonblocking_synchronization" => Preferences.load_preference(CUDA, "nonblocking_synchronization"), + "default_memory" => Preferences.load_preference(CUDA, "default_memory"), + "CUDA_Runtime_jll.version" => Preferences.load_preference(CUDA_Runtime_jll, "version"), + "CUDA_Runtime_jll.local" => Preferences.load_preference(CUDA_Runtime_jll, "local"), + "CUDA_Driver_jll.compat" => Preferences.load_preference(CUDA_Driver_jll, "compat"), + ] + if any(x->!isnothing(x[2]), prefs) + println(io, "Preferences:") + for (key, val) in prefs + if !isnothing(val) + println(io, "- $key: $val") + end + end + println(io) + end + devs = devices() if isempty(devs) println(io, "No CUDA-capable devices.") diff --git a/test/base/array.jl b/test/base/array.jl index f7c3c77bef..f244cfa8e0 100644 --- a/test/base/array.jl +++ b/test/base/array.jl @@ -3,67 +3,31 @@ import Adapt using ChainRulesCore: add!!, is_inplaceable_destination @testset "constructors" begin - xs = CuArray{Int}(undef, 2, 3) - @test device(xs) == device() - @test context(xs) == context() - @test collect(CuArray([1 2; 3 4])) == [1 2; 3 4] - @test collect(cu[1, 2, 3]) == [1, 2, 3] - @test collect(cu([1, 2, 3])) == [1, 2, 3] - @test testf(vec, rand(5,3)) - @test cu(1:3) === 1:3 - @test Base.elsize(xs) == sizeof(Int) - @test pointer(CuArray{Int, 2}(xs)) != pointer(xs) - - # test aggressive conversion to Float32, but only for floats, and only with `cu` - @test cu([1]) isa CuArray{Int} - @test cu(Float64[1]) isa CuArray{Float32} - @test cu(ComplexF64[1+1im]) isa CuArray{ComplexF32} - @test Adapt.adapt(CuArray, Float64[1]) isa CuArray{Float64} - @test Adapt.adapt(CuArray, ComplexF64[1]) isa CuArray{ComplexF64} - @test Adapt.adapt(CuArray{Float16}, Float64[1]) isa CuArray{Float16} - - @test_throws ArgumentError Base.unsafe_convert(Ptr{Int}, xs) - @test_throws ArgumentError Base.unsafe_convert(Ptr{Float32}, xs) - - # unsafe_wrap - let - arr = CuArray{Int}(undef, 2) - ptr = pointer(arr) - B = Mem.DeviceBuffer - - ## compare the fields we care about - function test_eq(a, T, N, dims) - @test eltype(a) == T - @test ndims(a) == N - @test a.data[].ptr == ptr - @test a.data[].ctx == context() - @test a.maxsize == arr.maxsize - @test a.offset == arr.offset - @test a.dims == dims - end - - test_eq(unsafe_wrap(CuArray, ptr, 2), Int, 1, (2,)) - 
test_eq(unsafe_wrap(CuArray{Int}, ptr, 2), Int, 1, (2,)) - test_eq(unsafe_wrap(CuArray{Int,1}, ptr, 2), Int, 1, (2,)) - test_eq(unsafe_wrap(CuArray{Int,1,B}, ptr, 2), Int, 1, (2,)) - test_eq(unsafe_wrap(CuArray, ptr, (1,2)), Int, 2, (1,2)) - test_eq(unsafe_wrap(CuArray{Int}, ptr, (1,2)), Int, 2, (1,2)) - test_eq(unsafe_wrap(CuArray{Int,2}, ptr, (1,2)), Int, 2, (1,2)) - test_eq(unsafe_wrap(CuArray{Int,2,B}, ptr, (1,2)), Int, 2, (1,2)) - - @test_throws ErrorException unsafe_wrap(CuArray{Int,1,Mem.HostBuffer}, ptr, 2) - @test_throws ErrorException unsafe_wrap(CuArray{Int,2,Mem.HostBuffer}, ptr, (1,2)) - end - let buf = Mem.alloc(Mem.Host, sizeof(Int), Mem.HOSTALLOC_DEVICEMAP) - gpu_ptr = convert(CuPtr{Int}, buf) - gpu_arr = unsafe_wrap(CuArray, gpu_ptr, 1) - gpu_arr .= 42 - - synchronize() - - cpu_ptr = convert(Ptr{Int}, buf) - cpu_arr = unsafe_wrap(Array, cpu_ptr, 1) - @test cpu_arr == [42] + let xs = CuArray{Int}(undef, 2, 3) + # basic properties + @test device(xs) == device() + @test context(xs) == context() + @test collect(CuArray([1 2; 3 4])) == [1 2; 3 4] + @test collect(cu[1, 2, 3]) == [1, 2, 3] + @test collect(cu([1, 2, 3])) == [1, 2, 3] + @test testf(vec, rand(5,3)) + @test cu(1:3) === 1:3 + @test Base.elsize(xs) == sizeof(Int) + @test pointer(CuArray{Int, 2}(xs)) != pointer(xs) + + # test aggressive conversion to Float32, but only for floats, and only with `cu` + @test cu([1]) isa CuArray{Int} + @test cu(Float64[1]) isa CuArray{Float32} + @test cu(ComplexF64[1+1im]) isa CuArray{ComplexF32} + @test Adapt.adapt(CuArray, Float64[1]) isa CuArray{Float64} + @test Adapt.adapt(CuArray, ComplexF64[1]) isa CuArray{ComplexF64} + @test Adapt.adapt(CuArray{Float16}, Float64[1]) isa CuArray{Float16} + end + + # test pointer conversions + let xs = CuVector{Int,Mem.DeviceBuffer}(undef, 1) + @test_throws ArgumentError Base.unsafe_convert(Ptr{Int}, xs) + @test_throws ArgumentError Base.unsafe_convert(Ptr{Float32}, xs) end @test collect(CUDA.zeros(2, 2)) == zeros(Float32, 2, 2) @@ -87,6 +51,69 @@ using ChainRulesCore: add!!, is_inplaceable_destination end end +@testset "unsafe_wrap" begin + hmm = CUDA.driver_version() >= v"12.2" && + attribute(device(), CUDA.DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS) == 1 + + # managed memory -> CuArray + for a in [cu([1]; device=true), cu([1]; unified=true)] + p = pointer(a) + for AT in [CuArray, CuArray{Int}, CuArray{Int,1}, typeof(a)], + b in [unsafe_wrap(AT, p, 1), unsafe_wrap(AT, p, (1,))] + @test typeof(b) == typeof(a) + @test pointer(b) == p + @test size(b) == (1,) + end + end + + # managed memory -> Array + let a = cu([1]; unified=true) + p = pointer(a) + for AT in [Array, Array{Int}, Array{Int,1}], + b in [unsafe_wrap(AT, p, 1), unsafe_wrap(AT, p, (1,)), unsafe_wrap(AT, a)] + @test typeof(b) == Array{Int,1} + @test pointer(b) == reinterpret(Ptr{Int}, p) + @test size(b) == (1,) + end + end + + # unmanaged memory -> CuArray + if hmm + a = [1] + p = pointer(a) + for AT in [CuArray, CuArray{Int}, CuArray{Int,1}, CuArray{Int,1,Mem.UnifiedBuffer}], + b in [unsafe_wrap(AT, p, 1), unsafe_wrap(AT, p, (1,)), unsafe_wrap(AT, a)] + @test typeof(b) == CuArray{Int,1,Mem.UnifiedBuffer} + @test pointer(b) == reinterpret(CuPtr{Int}, p) + @test size(b) == (1,) + end + end + + # errors + let a = cu([1]; device=true) + @test_throws ArgumentError unsafe_wrap(Array, a) + @test_throws ArgumentError unsafe_wrap(CuArray{Int,1,Mem.UnifiedBuffer}, pointer(a), 1) + end + if hmm + let a = [1] + @test_throws ArgumentError unsafe_wrap(CuArray{Int,1,Mem.DeviceBuffer}, a) + end + end + + # 
some actual operations + let buf = Mem.alloc(Mem.Host, sizeof(Int), Mem.HOSTALLOC_DEVICEMAP) + gpu_ptr = convert(CuPtr{Int}, buf) + gpu_arr = unsafe_wrap(CuArray, gpu_ptr, 1) + gpu_arr .= 42 + + synchronize() + + cpu_ptr = convert(Ptr{Int}, buf) + cpu_arr = unsafe_wrap(Array, cpu_ptr, 1) + @test cpu_arr == [42] + end +end + @testset "adapt" begin A = rand(Float32, 3, 3) dA = CuArray(A) @@ -666,7 +693,7 @@ end dev = device() let - a = CuVector{Int}(undef, 1) + a = CuVector{Int,Mem.DeviceBuffer}(undef, 1) @test !is_unified(a) @test !is_managed(pointer(a)) end @@ -689,12 +716,6 @@ end end let - # default ctor: device memory - let a = CUDA.rand(1) - @test !is_unified(a) - @test !is_managed(pointer(a)) - end - for B = [Mem.DeviceBuffer, Mem.UnifiedBuffer] a = CuVector{Float32,B}(rand(Float32, 1)) @test !xor(B == Mem.UnifiedBuffer, is_unified(a)) @@ -735,12 +756,12 @@ end end # cu: supports unified keyword - let a = cu(rand(Float64, 1); unified=true) - @test is_unified(a) + let a = cu(rand(Float64, 1); device=true) + @test !is_unified(a) @test eltype(a) == Float32 end - let a = cu(rand(Float64, 1)) - @test !is_unified(a) + let a = cu(rand(Float64, 1); unified=true) + @test is_unified(a) @test eltype(a) == Float32 end end @@ -863,4 +884,4 @@ end c = add!!(a, b) @test c == a′ + b @test c === a -end \ No newline at end of file +end diff --git a/test/core/execution.jl b/test/core/execution.jl index cacdea8e6e..e6e13e91a1 100644 --- a/test/core/execution.jl +++ b/test/core/execution.jl @@ -470,7 +470,7 @@ end @eval struct Host end @eval struct Device end - Adapt.adapt_storage(::CUDA.Adaptor, a::Host) = Device() + Adapt.adapt_storage(::CUDA.KernelAdaptor, a::Host) = Device() Base.convert(::Type{Int}, ::Host) = 1 Base.convert(::Type{Int}, ::Device) = 2 diff --git a/test/runtests.jl b/test/runtests.jl index e02a00b0e0..b744bea2cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,6 +58,10 @@ const tests = ["core$(path_separator)initialization"] # needs to run first const test_runners = Dict() ## GPUArrays testsuite for name in keys(TestSuite.tests) + if CUDA.default_memory != Mem.Device && name == "indexing scalar" + # GPUArrays' scalar indexing tests assume that indexing is not supported + continue + end push!(tests, "gpuarrays$(path_separator)$name") test_runners["gpuarrays$(path_separator)$name"] = ()->TestSuite.tests[name](CuArray) end