From 3b7a1ac142e741d100807ec425088131a3e92c66 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 19 Dec 2023 13:30:46 +0100 Subject: [PATCH] Remove the N argument from GPUArrays.derive. (#508) --- lib/JLArrays/src/JLArrays.jl | 2 +- src/host/base.jl | 17 ++++++++--------- src/host/construction.jl | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index f10d7176..0b3170a3 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -190,7 +190,7 @@ function typed_data(x::JLArray{T}) where {T} unsafe_wrap(Array, pointer(x), x.dims) end -function GPUArrays.derive(::Type{T}, N::Int, a::JLArray, dims::Dims, offset::Int) where {T} +function GPUArrays.derive(::Type{T}, a::JLArray, dims::Dims{N}, offset::Int) where {T,N} ref = copy(a.data) offset = (a.offset * Base.elsize(a)) ÷ sizeof(T) + offset JLArray{T,N}(ref, dims; offset) diff --git a/src/host/base.jl b/src/host/base.jl index 20cfd563..28f7166a 100644 --- a/src/host/base.jl +++ b/src/host/base.jl @@ -155,7 +155,7 @@ function Base.reshape(a::AbstractGPUArray{T,M}, dims::NTuple{N,Int}) where {T,N, return a end - derive(T, N, a, dims, 0) + derive(T, a, dims, 0) end @@ -173,7 +173,7 @@ function Base.reinterpret(::Type{T}, a::AbstractGPUArray{S,N}) where {T,S,N} osize = tuple(size1, Base.tail(isize)...) end - return derive(T, N, a, osize, 0) + return derive(T, a, osize, 0) end function _reinterpret_exception(::Type{T}, a::AbstractArray{S,N}) where {T,S,N} @@ -229,8 +229,8 @@ end ## reinterpret(reshape) function Base.reinterpret(::typeof(reshape), ::Type{T}, a::AbstractGPUArray) where {T} - N, osize = _base_check_reshape_reinterpret(T, a) - return derive(T, N, a, osize, 0) + osize = _base_check_reshape_reinterpret(T, a) + return derive(T, a, osize, 0) end # taken from reinterpretarray.jl @@ -240,21 +240,20 @@ function _base_check_reshape_reinterpret(::Type{T}, a::AbstractGPUArray{S}) wher isbitstype(S) || throwbits(S, T, S) if sizeof(S) == sizeof(T) N = ndims(a) - osize = size(a) + size(a) elseif sizeof(S) > sizeof(T) d, r = divrem(sizeof(S), sizeof(T)) r == 0 || throwintmult(S, T) N = ndims(a) + 1 - osize = (d, size(a)...) + (d, size(a)...) else d, r = divrem(sizeof(T), sizeof(S)) r == 0 || throwintmult(S, T) N = ndims(a) - 1 N > -1 || throwsize0(S, T, "larger") axes(a, 1) == Base.OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T) - osize = size(a)[2:end] + size(a)[2:end] end - return N, osize end @noinline function throwbits(S::Type, T::Type, U::Type) @@ -321,7 +320,7 @@ end @inline function unsafe_contiguous_view(a::AbstractGPUArray{T}, I::NTuple{N,Base.ViewIndex}, dims::NTuple{M,Integer}) where {T,N,M} offset = Base.compute_offset1(a, 1, I) - derive(T, M, a, dims, offset) + derive(T, a, dims, offset) end @inline function unsafe_view(A, I, ::NonContiguous) diff --git a/src/host/construction.jl b/src/host/construction.jl index f2f21f67..a456606b 100644 --- a/src/host/construction.jl +++ b/src/host/construction.jl @@ -140,5 +140,5 @@ end # size, but backed by the same data. The `additional_offset` is the number of elements # to offset the new array from the original array. -derive(::Type, N::Int, a::AbstractGPUArray, osize::Dims, additional_offset::Int) = +derive(::Type, a::AbstractGPUArray, osize::Dims, additional_offset::Int) = error("Not implemented")