Skip to content

Commit

Permalink
linalg: Support more inputs to tril! and triu!; printarray and getpro…
Browse files Browse the repository at this point in the history
…perty of QR.
  • Loading branch information
evelyne-ringoot authored and maleadt committed Jul 5, 2024
1 parent 40fa8c0 commit ee6ebe2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
3 changes: 2 additions & 1 deletion lib/GPUArraysCore/src/GPUArraysCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Adapt

export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat,
WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle,
AnyGPUArray, AnyGPUVector, AnyGPUMatrix
AnyGPUArray, AnyGPUVector, AnyGPUMatrix, AnyGPUVecOrMat

"""
AbstractGPUArray{T, N} <: DenseArray{T, N}
Expand All @@ -27,6 +27,7 @@ const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{
const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
const AnyGPUVector{T} = AnyGPUArray{T, 1}
const AnyGPUMatrix{T} = AnyGPUArray{T, 2}
const AnyGPUVecOrMat{T} = Union{AnyGPUArray{T, 1}, AnyGPUArray{T, 2}}


## broadcasting
Expand Down
22 changes: 20 additions & 2 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriang
@eval Base.copyto!(A::$T{T, <:AbstractGPUArray{T,N}}, B::$T{T, <:AbstractGPUArray{T,N}}) where {T,N} = $T(copyto!(parent(A), parent(B)))
end

function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
function LinearAlgebra.tril!(A::AnyGPUMatrix{T}, d::Integer = 0) where T
gpu_call(A, d; name="tril!") do ctx, _A, _d
I = @cartesianidx _A
i, j = Tuple(I)
Expand All @@ -182,7 +182,7 @@ function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
return A
end

function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
function LinearAlgebra.triu!(A::AnyGPUMatrix{T}, d::Integer = 0) where T
gpu_call(A, d; name="triu!") do ctx, _A, _d
I = @cartesianidx _A
i, j = Tuple(I)
Expand Down Expand Up @@ -795,3 +795,21 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T}

Array(y)[]
end

## QR

import LinearAlgebra: QRPackedQ

function LinearAlgebra.getproperty(F::QR{T,<:AnyGPUMatrix{T}}, d::Symbol) where {T}
m, n = size(F)
if d === :R
return triu!(view(getfield(F, :factors), 1:min(m,n), 1:n))
elseif d === :Q
return LinearAlgebra.QRPackedQ(getfield(F, :factors), F.τ)
else
getfield(F, d)
end
end

Base.print_array(io::IO, Q::QRPackedQ{T,<:AnyGPUMatrix{T},<:AnyGPUMatrix{T}}) where {T} =
Base.print_array(io, collect(adapt(ToArray(), Q)))
7 changes: 7 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,10 @@ end
@test isrealfloattype(typeof(opnorm(AT(mat), p)))
end
end

@testsuite "QR" (AT, eltypes)->begin
@testset "get property" for dims in [(3,5),(3,3),(5,3)],
prop in [:Q, :R], T in eltypes
@test compare(x -> getproperty(qr(x), prop), AT, rand(T, dims))
end
end

0 comments on commit ee6ebe2

Please sign in to comment.