Skip to content

Commit

Permalink
Remove the N argument from GPUArrays.derive. (#508)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Dec 19, 2023
1 parent b2c6998 commit 3b7a1ac
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
2 changes: 1 addition & 1 deletion lib/JLArrays/src/JLArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function typed_data(x::JLArray{T}) where {T}
unsafe_wrap(Array, pointer(x), x.dims)
end

function GPUArrays.derive(::Type{T}, N::Int, a::JLArray, dims::Dims, offset::Int) where {T}
function GPUArrays.derive(::Type{T}, a::JLArray, dims::Dims{N}, offset::Int) where {T,N}
ref = copy(a.data)
offset = (a.offset * Base.elsize(a)) ÷ sizeof(T) + offset
JLArray{T,N}(ref, dims; offset)
Expand Down
17 changes: 8 additions & 9 deletions src/host/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ function Base.reshape(a::AbstractGPUArray{T,M}, dims::NTuple{N,Int}) where {T,N,
return a
end

derive(T, N, a, dims, 0)
derive(T, a, dims, 0)
end


Expand All @@ -173,7 +173,7 @@ function Base.reinterpret(::Type{T}, a::AbstractGPUArray{S,N}) where {T,S,N}
osize = tuple(size1, Base.tail(isize)...)
end

return derive(T, N, a, osize, 0)
return derive(T, a, osize, 0)
end

function _reinterpret_exception(::Type{T}, a::AbstractArray{S,N}) where {T,S,N}
Expand Down Expand Up @@ -229,8 +229,8 @@ end
## reinterpret(reshape)

function Base.reinterpret(::typeof(reshape), ::Type{T}, a::AbstractGPUArray) where {T}
N, osize = _base_check_reshape_reinterpret(T, a)
return derive(T, N, a, osize, 0)
osize = _base_check_reshape_reinterpret(T, a)
return derive(T, a, osize, 0)
end

# taken from reinterpretarray.jl
Expand All @@ -240,21 +240,20 @@ function _base_check_reshape_reinterpret(::Type{T}, a::AbstractGPUArray{S}) wher
isbitstype(S) || throwbits(S, T, S)
if sizeof(S) == sizeof(T)
N = ndims(a)
osize = size(a)
size(a)
elseif sizeof(S) > sizeof(T)
d, r = divrem(sizeof(S), sizeof(T))
r == 0 || throwintmult(S, T)
N = ndims(a) + 1
osize = (d, size(a)...)
(d, size(a)...)
else
d, r = divrem(sizeof(T), sizeof(S))
r == 0 || throwintmult(S, T)
N = ndims(a) - 1
N > -1 || throwsize0(S, T, "larger")
axes(a, 1) == Base.OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T)
osize = size(a)[2:end]
size(a)[2:end]
end
return N, osize
end

@noinline function throwbits(S::Type, T::Type, U::Type)
Expand Down Expand Up @@ -321,7 +320,7 @@ end
@inline function unsafe_contiguous_view(a::AbstractGPUArray{T}, I::NTuple{N,Base.ViewIndex}, dims::NTuple{M,Integer}) where {T,N,M}
offset = Base.compute_offset1(a, 1, I)

derive(T, M, a, dims, offset)
derive(T, a, dims, offset)
end

@inline function unsafe_view(A, I, ::NonContiguous)
Expand Down
2 changes: 1 addition & 1 deletion src/host/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,5 @@ end
# size, but backed by the same data. The `additional_offset` is the number of elements
# to offset the new array from the original array.

derive(::Type, N::Int, a::AbstractGPUArray, osize::Dims, additional_offset::Int) =
derive(::Type, a::AbstractGPUArray, osize::Dims, additional_offset::Int) =
error("Not implemented")

0 comments on commit 3b7a1ac

Please sign in to comment.