Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Abstract types for sparse arrays #577

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ Manifest.toml

# MacOS generated files
*.DS_Store

/.vscode/
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Expand All @@ -24,5 +25,6 @@ Printf = "1"
Random = "1"
Reexport = "1"
Serialization = "1"
SparseArrays = "1"
Statistics = "1"
julia = "1.10"
14 changes: 0 additions & 14 deletions lib/GPUArraysCore/Manifest.toml

This file was deleted.

4 changes: 4 additions & 0 deletions lib/GPUArraysCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ version = "0.2.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
Adapt = "4.0"
LinearAlgebra = "1"
SparseArrays = "1"
julia = "1.6"
51 changes: 35 additions & 16 deletions lib/GPUArraysCore/src/GPUArraysCore.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
module GPUArraysCore

using Adapt

using LinearAlgebra
using SparseArrays

## essential types

export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat,
WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle,
AnyGPUArray, AnyGPUVector, AnyGPUMatrix
WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle,
AnyGPUArray, AnyGPUVector, AnyGPUMatrix

export AbstractGPUSparseArray, AbstractGPUSparseMatrix, AbstractGPUSparseVector, AbstractGPUSparseVecOrMat,
AbstractGPUSparseMatrixCSC, AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCOO, AnyGPUSparseMatrixCSC, AnyGPUSparseMatrixCSR, AnyGPUSparseMatrixCOO

"""
AbstractGPUArray{T, N} <: DenseArray{T, N}
Expand All @@ -16,18 +20,33 @@ Supertype for `N`-dimensional GPU arrays (or array-like types) with elements of
Instances of this type are expected to live on the host, see [`AbstractDeviceArray`](@ref)
for device-side objects.
"""
abstract type AbstractGPUArray{T, N} <: DenseArray{T, N} end
abstract type AbstractGPUArray{T,N} <: DenseArray{T,N} end

const AbstractGPUVector{T} = AbstractGPUArray{T, 1}
const AbstractGPUMatrix{T} = AbstractGPUArray{T, 2}
const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T, 1}, AbstractGPUArray{T, 2}}
const AbstractGPUVector{T} = AbstractGPUArray{T,1}
const AbstractGPUMatrix{T} = AbstractGPUArray{T,2}
const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T,1},AbstractGPUArray{T,2}}

# convenience aliases for working with wrapped arrays
const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}}
const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
const AnyGPUVector{T} = AnyGPUArray{T, 1}
const AnyGPUMatrix{T} = AnyGPUArray{T, 2}
const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N},WrappedGPUArray{T,N}}
const AnyGPUVector{T} = AnyGPUArray{T,1}
const AnyGPUMatrix{T} = AnyGPUArray{T,2}

## sparse arrays

abstract type AbstractGPUSparseArray{Tv,Ti,N} <: AbstractSparseArray{Tv,Ti,N} end

const AbstractGPUSparseMatrix{Tv,Ti} = AbstractGPUSparseArray{Tv,Ti,2}
const AbstractGPUSparseVector{Tv,Ti} = AbstractGPUSparseArray{Tv,Ti,1}
const AbstractGPUSparseVecOrMat{Tv,Ti} = Union{AbstractGPUSparseVector{Tv,Ti},AbstractGPUSparseMatrix{Tv,Ti}}

abstract type AbstractGPUSparseMatrixCSC{Tv,Ti<:Integer} <: AbstractGPUSparseMatrix{Tv,Ti} end
abstract type AbstractGPUSparseMatrixCSR{Tv,Ti<:Integer} <: AbstractGPUSparseMatrix{Tv,Ti} end
abstract type AbstractGPUSparseMatrixCOO{Tv,Ti<:Integer} <: AbstractGPUSparseMatrix{Tv,Ti} end

const AnyGPUSparseMatrixCSC{Tv,Ti} = Union{AbstractGPUSparseMatrixCSC{Tv,Ti},Transpose{Tv,<:AbstractGPUSparseMatrixCSC{Tv,Ti}},Adjoint{Tv,<:AbstractGPUSparseMatrixCSC{Tv,Ti}}}
const AnyGPUSparseMatrixCSR{Tv,Ti} = Union{AbstractGPUSparseMatrixCSR{Tv,Ti},Transpose{Tv,<:AbstractGPUSparseMatrixCSR{Tv,Ti}},Adjoint{Tv,<:AbstractGPUSparseMatrixCSR{Tv,Ti}}}
const AnyGPUSparseMatrixCOO{Tv,Ti} = Union{AbstractGPUSparseMatrixCOO{Tv,Ti},Transpose{Tv,<:AbstractGPUSparseMatrixCOO{Tv,Ti}},Adjoint{Tv,<:AbstractGPUSparseMatrixCOO{Tv,Ti}}}

## broadcasting

Expand Down Expand Up @@ -157,9 +176,9 @@ end
# this problem will be introduced in https://github.com/JuliaLang/julia/pull/39217
macro __tryfinally(ex, fin)
Expr(:tryfinally,
:($(esc(ex))),
:($(esc(fin)))
)
:($(esc(ex))),
:($(esc(fin)))
)
end

"""
Expand All @@ -182,7 +201,7 @@ end
function allowscalar(allow::Bool=true)
if allow
@warn """It's not recommended to use allowscalar([true]) to allow scalar indexing.
Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.""" maxlog=1
Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.""" maxlog = 1
end
setting = allow ? ScalarAllowed : ScalarDisallowed
task_local_storage(:ScalarIndexing, setting)
Expand All @@ -204,8 +223,8 @@ macro allowscalar(ex)
local tls_value = get(task_local_storage(), :ScalarIndexing, nothing)
task_local_storage(:ScalarIndexing, ScalarAllowed)
@__tryfinally($(esc(ex)),
isnothing(tls_value) ? delete!(task_local_storage(), :ScalarIndexing)
: task_local_storage(:ScalarIndexing, tls_value))
isnothing(tls_value) ? delete!(task_local_storage(), :ScalarIndexing)
: task_local_storage(:ScalarIndexing, tls_value))
end
end

Expand Down
2 changes: 2 additions & 0 deletions lib/JLArrays/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
Adapt = "2.0, 3.0, 4.0"
GPUArrays = "11.1"
KernelAbstractions = "0.9"
Random = "1"
SparseArrays = "1"
julia = "1.8"
4 changes: 4 additions & 0 deletions lib/JLArrays/src/JLArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ export JLArray, JLVector, JLMatrix, jl, JLBackend
using GPUArrays

using Adapt
using SparseArrays
using SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds

import KernelAbstractions
import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
Expand Down Expand Up @@ -387,4 +389,6 @@ Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)

include("sparse.jl")

end
95 changes: 95 additions & 0 deletions lib/JLArrays/src/sparse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
export JLSparseVector, JLSparseMatrixCSC

## Sparse Vector

"""
    JLSparseVector{Tv,Ti} <: AbstractGPUSparseVector{Tv,Ti}

Reference (CPU-backed) sparse vector whose index and value buffers are
stored in `JLVector`s, mirroring the layout of `SparseArrays.SparseVector`.
"""
struct JLSparseVector{Tv,Ti<:Integer} <: AbstractGPUSparseVector{Tv,Ti}
    n::Ti                # Length of the sparse vector
    nzind::JLVector{Ti}  # Indices of stored values
    nzval::JLVector{Tv}  # Stored values, typically nonzeros

    function JLSparseVector{Tv,Ti}(n::Integer, nzind::JLVector{Ti}, nzval::JLVector{Tv}) where {Tv,Ti<:Integer}
        # Validate before constructing: a negative length or mismatched
        # index/value buffers would silently corrupt downstream kernels.
        if n < 0
            throw(ArgumentError("The number of elements must be non-negative."))
        end
        if length(nzind) != length(nzval)
            throw(ArgumentError("index and value vectors must be the same length"))
        end
        return new(convert(Ti, n), nzind, nzval)
    end
end

# Convenience constructor: infer {Tv,Ti} from the buffer element types.
function JLSparseVector(n::Integer, nzind::JLVector{Ti}, nzval::JLVector{Tv}) where {Tv,Ti}
    return JLSparseVector{Tv,Ti}(n, nzind, nzval)
end

# Host <-> JL conversions; each direction copies the underlying buffers.
function JLSparseVector(V::SparseVector)
    return JLSparseVector(V.n, JLVector(V.nzind), JLVector(V.nzval))
end
function SparseVector(V::JLSparseVector)
    return SparseVector(V.n, Vector(V.nzind), Vector(V.nzval))
end

Base.copy(V::JLSparseVector) = JLSparseVector(V.n, copy(V.nzind), copy(V.nzval))

# AbstractArray interface: `length` is expected to return an `Int` and
# `size` a `Dims` tuple (NTuple of `Int`). The logical length is stored
# with index type `Ti`, so convert explicitly instead of leaking `Ti`
# (e.g. `Int32`) to generic array code.
Base.length(V::JLSparseVector) = Int(V.n)
Base.size(V::JLSparseVector) = (Int(V.n),)

# SparseArrays accessor interface: expose the raw storage buffers.
function SparseArrays.nonzeros(V::JLSparseVector)
    return V.nzval
end
function SparseArrays.nonzeroinds(V::JLSparseVector)
    return V.nzind
end

## SparseMatrixCSC

"""
    JLSparseMatrixCSC{Tv,Ti} <: AbstractGPUSparseMatrixCSC{Tv,Ti}

Reference (CPU-backed) compressed-sparse-column matrix whose buffers are
stored in `JLVector`s, mirroring `SparseArrays.SparseMatrixCSC`.
"""
struct JLSparseMatrixCSC{Tv,Ti<:Integer} <: AbstractGPUSparseMatrixCSC{Tv,Ti}
    m::Int                # Number of rows
    n::Int                # Number of columns
    colptr::JLVector{Ti}  # Column i is in colptr[i]:(colptr[i+1]-1)
    rowval::JLVector{Ti}  # Row indices of stored values
    nzval::JLVector{Tv}   # Stored values, typically nonzeros

    function JLSparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::JLVector{Ti},
                                      rowval::JLVector{Ti}, nzval::JLVector{Tv}) where {Tv,Ti<:Integer}
        # Reuse SparseArrays' check that Ti can index an m-by-n matrix.
        SparseArrays.sparse_check_Ti(m, n, Ti)
        # Structural buffer validation (lengths/consistency) without scalar indexing.
        if !GPUArrays._goodbuffers_csc(m, n, colptr, rowval, nzval)
            throw(ArgumentError("Invalid buffers for JLSparseMatrixCSC construction n=$n, colptr=$(summary(colptr)), rowval=$(summary(rowval)), nzval=$(summary(nzval))"))
        end
        return new(Int(m), Int(n), colptr, rowval, nzval)
    end
end
# Convenience constructor: infer {Tv,Ti} from the buffers, validate the
# index type, and trim oversized buffers (matching SparseMatrixCSC semantics).
function JLSparseMatrixCSC(m::Integer, n::Integer, colptr::JLVector, rowval::JLVector, nzval::JLVector)
    Tv = eltype(nzval)
    Ti = promote_type(eltype(colptr), eltype(rowval))
    SparseArrays.sparse_check_Ti(m, n, Ti)
    # SparseArrays.sparse_check(n, colptr, rowval, nzval) # TODO: this uses scalar indexing
    # silently shorten rowval and nzval to usable index positions.
    maxlen = abs(widemul(m, n))
    if isbitstype(Ti)
        maxlen = min(maxlen, typemax(Ti) - 1)
    end
    if length(rowval) > maxlen
        resize!(rowval, maxlen)
    end
    if length(nzval) > maxlen
        resize!(nzval, maxlen)
    end
    return JLSparseMatrixCSC{Tv,Ti}(m, n, colptr, rowval, nzval)
end

# Host <-> JL conversions; each direction copies all three CSC buffers.
function JLSparseMatrixCSC(A::SparseMatrixCSC)
    return JLSparseMatrixCSC(A.m, A.n, JLVector(A.colptr), JLVector(A.rowval), JLVector(A.nzval))
end
function SparseMatrixCSC(A::JLSparseMatrixCSC)
    return SparseMatrixCSC(A.m, A.n, Vector(A.colptr), Vector(A.rowval), Vector(A.nzval))
end

Base.copy(A::JLSparseMatrixCSC) = JLSparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), copy(A.nzval))

# AbstractArray interface: dimensions are stored as `Int` fields, and
# `length` is the total number of (dense) elements, stored or not.
function Base.size(A::JLSparseMatrixCSC)
    return (A.m, A.n)
end
function Base.length(A::JLSparseMatrixCSC)
    return prod(size(A))
end

# SparseArrays accessor interface: expose the raw CSC storage buffers.
function SparseArrays.nonzeros(A::JLSparseMatrixCSC)
    return A.nzval
end
function SparseArrays.getcolptr(A::JLSparseMatrixCSC)
    return A.colptr
end
function SparseArrays.rowvals(A::JLSparseMatrixCSC)
    return A.rowval
end

## Device

# Convert a host-side JLSparseMatrixCSC into its device-side counterpart,
# adapting each storage buffer with `to` (scalar dims pass through as-is).
function Adapt.adapt_structure(to, A::JLSparseMatrixCSC)
    return JLSparseDeviceMatrixCSC(A.m, A.n,
                                   Adapt.adapt(to, getcolptr(A)),
                                   Adapt.adapt(to, rowvals(A)),
                                   Adapt.adapt(to, nonzeros(A)))
end

# Device-side view of a JLSparseMatrixCSC: same CSC layout, but the
# buffers are JLDeviceArrays produced by Adapt.adapt_structure.
# Consistency fix: constrain `Ti<:Integer` like the host-side
# JLSparseMatrixCSC does — the supertype already requires it, so this
# only makes the invalid parameterization fail at the declaration site.
struct JLSparseDeviceMatrixCSC{Tv,Ti<:Integer} <: AbstractGPUSparseMatrixCSC{Tv,Ti}
    m::Int                      # Number of rows
    n::Int                      # Number of columns
    colptr::JLDeviceArray{Ti,1} # Column i is in colptr[i]:(colptr[i+1]-1)
    rowval::JLDeviceArray{Ti,1} # Row indices of stored values
    nzval::JLDeviceArray{Tv,1}  # Stored values, typically nonzeros
end

# Same SparseArrays accessor interface for the device-side wrapper.
function SparseArrays.nonzeros(A::JLSparseDeviceMatrixCSC)
    return A.nzval
end
function SparseArrays.getcolptr(A::JLSparseDeviceMatrixCSC)
    return A.colptr
end
function SparseArrays.rowvals(A::JLSparseDeviceMatrixCSC)
    return A.rowval
end
6 changes: 4 additions & 2 deletions src/GPUArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using KernelAbstractions
using Serialization
using Random
using LinearAlgebra
using SparseArrays
using SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds

using Printf

using LinearAlgebra.BLAS
Expand All @@ -15,8 +18,6 @@ using LLVM.Interop
using Reexport
@reexport using GPUArraysCore

using KernelAbstractions

# device functionality
include("device/abstractarray.jl")

Expand All @@ -33,6 +34,7 @@ include("host/math.jl")
include("host/random.jl")
include("host/quirks.jl")
include("host/uniformscaling.jl")
include("host/sparse.jl")
include("host/statistics.jl")


Expand Down
Loading