[NDTensorsMetalExt] Update for latest Unwrap
mtfishman committed Nov 9, 2023
1 parent 9556f9d · commit e1be8a6
Showing 11 changed files with 90 additions and 24 deletions.
3 changes: 2 additions & 1 deletion NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl
```diff
@@ -2,7 +2,7 @@ module NDTensorsMetalExt
 using Adapt
 using Functors
-using LinearAlgebra: LinearAlgebra, Transpose, mul!, qr, eigen, svd
+using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, mul!, qr, eigen, svd
 using NDTensors
 using NDTensors.SetParameters
 using NDTensors.Unwrap: qr_positive, ql_positive, ql
@@ -22,4 +22,5 @@ include("copyto.jl")
 include("append.jl")
 include("permutedims.jl")
 include("mul.jl")
+
 end
```
2 changes: 2 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/adapt.jl
```diff
@@ -1,3 +1,5 @@
+NDTensors.cpu(e::Exposed{<:MtlArray}) = adapt(Array, e)
+
 function mtl(xs; storage=DefaultStorageMode)
   return adapt(set_storagemode(MtlArray, storage), xs)
 end
```
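Together with the existing `mtl` adaptor above, the new `NDTensors.cpu` method completes the host/device round trip. A minimal usage sketch, assuming a Metal-capable machine with Metal.jl loaded (which activates this extension) and that `mtl` is reachable as `NDTensors.mtl`:

```julia
using Metal: MtlArray
using NDTensors
using NDTensors.Unwrap: expose

x = randn(Float32, 4, 4)
xdev = NDTensors.mtl(x)             # adapt the data to a Metal device array
@assert xdev isa MtlArray
xhost = NDTensors.cpu(expose(xdev)) # the new method: adapt an Exposed MtlArray back to Array
@assert xhost == x
```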
24 changes: 17 additions & 7 deletions NDTensors/ext/NDTensorsMetalExt/copyto.jl
```diff
@@ -1,13 +1,23 @@
-# Catches a bug in `copyto!` in Metal backend.
-function NDTensors.copyto!(
-  ::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::SubArray
+function Base.copy(src::Exposed{<:MtlArray,<:Base.ReshapedArray})
+  return reshape(copy(parent(src)), size(unexpose(src)))
+end
+
+function Base.copy(
+  src::Exposed{
+    <:MtlArray,<:SubArray{<:Any,<:Any,<:Base.ReshapedArray{<:Any,<:Any,<:Adjoint}}
+  },
 )
-  return Base.copyto!(dest, copy(src))
+  return copy(@view copy(expose(parent(src)))[parentindices(unexpose(src))...])
 end
 
 # Catches a bug in `copyto!` in Metal backend.
+function Base.copyto!(dest::Exposed{<:MtlArray}, src::Exposed{<:MtlArray,<:SubArray})
+  return copyto!(dest, expose(copy(src)))
+end
+
+# Catches a bug in `copyto!` in Metal backend.
-function NDTensors.copyto!(
-  ::Type{<:MtlArray}, dest::AbstractArray, ::Type{<:MtlArray}, src::Base.ReshapedArray
+function Base.copyto!(
+  dest::Exposed{<:MtlArray}, src::Exposed{<:MtlArray,<:Base.ReshapedArray}
 )
-  return NDTensors.copyto!(dest, parent(src))
+  return copyto!(dest, expose(parent(src)))
 end
```
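These methods now dispatch on `Exposed` wrappers rather than on raw array types, so the Metal `copyto!` workarounds live at the `expose` layer. A condensed sketch of the pattern they intercept, mirroring the new Unwrap tests below (a reshape of an adjoint, which Metal's native `copyto!` mishandles); `mtl` as in the previous sketch:

```julia
using NDTensors
using NDTensors.Unwrap: expose

a = NDTensors.mtl(randn(Float32, 4, 4))
src = reshape(a', 16)                # a ReshapedArray wrapping an Adjoint
dest = NDTensors.mtl(zeros(Float32, 16))
copyto!(expose(dest), expose(src))   # routed through the workaround methods above
@assert Array(dest) == reshape(Array(a)', 16)
```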
5 changes: 5 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/indexing.jl
```diff
@@ -6,3 +6,8 @@ function Base.setindex!(E::Exposed{<:MtlArray}, x::Number)
   Metal.@allowscalar unexpose(E)[] = x
   return unexpose(E)
 end
+
+# Shared with `CuArray`. Move to `NDTensorsGPUArraysCoreExt`?
+function Base.getindex(E::Exposed{<:MtlArray,<:Adjoint}, i, j)
+  return (expose(parent(E))[j, i])'
+end
```
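A sketch of what the new `Adjoint` method enables: scalar reads on a lazily adjointed device matrix, forwarded to the parent array. Hypothetical values; wrapped in `Metal.@allowscalar` since the inner read is still a scalar access:

```julia
using Metal
using NDTensors
using NDTensors.Unwrap: expose

m = NDTensors.mtl(Float32[1 2; 3 4])
E = expose(m')                  # Exposed{<:MtlArray,<:Adjoint}
x = Metal.@allowscalar E[1, 2]  # reads parent(m')[2, 1], then conjugates
@assert x == 3.0f0
```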
9 changes: 5 additions & 4 deletions NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl
```diff
@@ -23,8 +23,9 @@ function LinearAlgebra.eigen(A::Exposed{<:MtlMatrix})
 end
 
 function LinearAlgebra.svd(A::Exposed{<:MtlMatrix}; kwargs...)
-  U, S, V = svd(expose(NDTensors.cpu(A)); kwargs...)
-  return adapt(unwrap_type(A), U),
-  adapt(set_ndims(unwrap_type(A), ndims(S)), S),
-  adapt(unwrap_type(A), V)
+  Ucpu, Scpu, Vcpu = svd(expose(NDTensors.cpu(A)); kwargs...)
+  U = adapt(unwrap_type(A), Ucpu)
+  S = adapt(set_ndims(unwrap_type(A), ndims(Scpu)), Scpu)
+  V = adapt(unwrap_type(A), Vcpu)
+  return U, S, V
 end
```
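Metal has no native SVD, so this method round-trips through the CPU: factorize there, then `adapt` each factor back to the device (with `S` adapted to a one-dimensional device array via `set_ndims`). The refactor only names the CPU factors before adapting; behavior is unchanged. A usage sketch mirroring the test suite:

```julia
using LinearAlgebra: Diagonal, svd
using NDTensors
using NDTensors.Unwrap: expose

A = NDTensors.mtl(randn(Float32, 6, 4))
U, S, V = svd(expose(A))       # computed on the CPU, factors returned on the device
@assert U * Diagonal(S) * V' ≈ A
```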
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsMetalExt/permutedims.jl
```diff
@@ -1,4 +1,4 @@
-function permutedims!(
+function Base.permutedims!(
   Edest::Exposed{<:MtlArray,<:Base.ReshapedArray}, Esrc::Exposed{<:MtlArray}, perm
 )
   Aperm = permutedims(Esrc, perm)
```
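The fix here is purely dispatch: the method previously defined a local `permutedims!` instead of extending `Base.permutedims!`, so it was never hit. A sketch of the call it now intercepts (a reshaped Metal destination, as in the new test below):

```julia
using NDTensors
using NDTensors.Unwrap: expose

y = Base.ReshapedArray(NDTensors.mtl(randn(Float32, 16)), (4, 4), ())
x = NDTensors.mtl(randn(Float32, 4, 4))
permutedims!(expose(y), expose(x), (2, 1))  # now reaches the method above
@assert NDTensors.cpu(y) == transpose(NDTensors.cpu(x))
```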
2 changes: 2 additions & 0 deletions NDTensors/src/Unwrap/src/Unwrap.jl
```diff
@@ -3,6 +3,7 @@ using SimpleTraits
 using LinearAlgebra
 using Base: ReshapedArray
 using StridedViews
+using Adapt: Adapt, adapt, adapt_structure
 
 include("expose.jl")
 include("iswrappedarray.jl")
@@ -16,6 +17,7 @@ include("functions/copyto.jl")
 include("functions/linearalgebra.jl")
 include("functions/mul.jl")
 include("functions/permutedims.jl")
+include("functions/adapt.jl")
 
 export IsWrappedArray,
   is_wrapped_array, parenttype, unwrap_type, expose, Exposed, unexpose, cpu
```
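For orientation, the Unwrap vocabulary used throughout this commit: `expose` tags an array together with its lazy wrappers in an `Exposed{UnwrapType,ObjectType}` so methods can dispatch on both layers, `unexpose` returns the original object, and `unwrap_type` recovers the fully unwrapped array type. A small CPU-only sketch:

```julia
using LinearAlgebra: transpose
using NDTensors.Unwrap: Exposed, expose, unexpose, unwrap_type

m = randn(Float32, 2, 3)
E = expose(transpose(m))           # Exposed{Matrix{Float32},<:Transpose}
@assert unwrap_type(E) == Matrix{Float32}
@assert parent(unexpose(E)) === m  # unexpose returns the wrapped Transpose itself
```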
8 changes: 8 additions & 0 deletions NDTensors/src/Unwrap/src/functions/adapt.jl
```diff
@@ -0,0 +1,8 @@
+Adapt.adapt(to, x::Exposed) = adapt_structure(to, x)
+Adapt.adapt_structure(to, x::Exposed) = adapt_structure(to, unexpose(x))
+
+# https://github.com/JuliaGPU/Adapt.jl/pull/51
+# TODO: Remove once https://github.com/JuliaGPU/Adapt.jl/issues/71 is addressed.
+function Adapt.adapt_structure(to, A::Exposed{<:Any,<:Hermitian})
+  return Hermitian(adapt(to, parent(unexpose(A))), Symbol(unexpose(A).uplo))
+end
```
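The `Hermitian` special case works around Adapt.jl dropping structured-matrix wrappers (see the linked issue): it adapts the parent and rebuilds the wrapper with the stored `uplo` flag. A CPU-only sketch of the effect:

```julia
using Adapt: adapt
using LinearAlgebra: Hermitian
using NDTensors.Unwrap: expose

H = Hermitian(randn(Float32, 3, 3))
H64 = adapt(Array{Float64}, expose(H))  # parent is adapted, wrapper is rebuilt
@assert H64 isa Hermitian && eltype(H64) == Float64
```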
54 changes: 45 additions & 9 deletions NDTensors/src/Unwrap/test/runtests.jl
```diff
@@ -5,7 +5,9 @@ using LinearAlgebra
 
 include("../../../test/device_list.jl")
 @testset "Testing Unwrap" for dev in devices_list(ARGS)
-  v = dev(Vector{Float64}(undef, 10))
+  elt = Float32
+
+  v = dev(Vector{elt}(undef, 10))
   vt = transpose(v)
   va = v'
 
@@ -37,7 +39,7 @@
   @test typeof(Et) == Exposed{m_type,LinearAlgebra.Transpose{e_type,m_type}}
   @test typeof(Ea) == Exposed{m_type,LinearAlgebra.Adjoint{e_type,m_type}}
 
-  o = dev(Vector{Float32})(undef, 1)
+  o = dev(Vector{elt})(undef, 1)
   expose(o)[] = 2
   @test expose(o)[] == 2
 
@@ -56,8 +58,8 @@
   q, r = Unwrap.qr_positive(expose(mp))
   @test q * r ≈ mp
 
-  square = dev(rand(Float64, (10, 10)))
-  square = (square + transpose(square)) ./ 2.0
+  square = dev(rand(elt, (10, 10)))
+  square = (square + transpose(square)) / 2
   ## CUDA only supports Hermitian or Symmetric eigen decompositions
   ## So I symmetrize square and call symetric here
   l, U = eigen(expose(Symmetric(square)))
@@ -66,25 +68,59 @@
   U, S, V, = svd(expose(mp))
   @test U * Diagonal(S) * V' ≈ mp
 
-  cm = dev(fill!(Matrix{Float64}(undef, (2, 2)), 0.0))
+  cm = dev(fill!(Matrix{elt}(undef, (2, 2)), 0.0))
   mul!(expose(cm), expose(mp), expose(mp'), 1.0, 0.0)
   @test cm ≈ mp * mp'
 
   @test permutedims(expose(mp), (2, 1)) == transpose(mp)
-  fill!(mt, 3.0)
+  fill!(mt, 3)
   permutedims!(expose(m), expose(mt), (2, 1))
-  @test norm(m) == sqrt(3^2 * 10)
+  @test norm(m) ≈ sqrt(3^2 * 10)
   @test size(m) == (5, 2)
   permutedims!(expose(m), expose(mt), (2, 1), +)
   @test size(m) == (5, 2)
-  @test norm(m) == sqrt(6^2 * 10)
+  @test norm(m) ≈ sqrt(6^2 * 10)
 
   m = reshape(m, (5, 2, 1))
   mt = fill!(similar(m), 3.0)
   m = permutedims(expose(m), (2, 1, 3))
   @test size(m) == (2, 5, 1)
   permutedims!(expose(m), expose(mt), (2, 1, 3))
-  @test norm(m) == sqrt(3^2 * 10)
+  @test norm(m) ≈ sqrt(3^2 * 10)
   permutedims!(expose(m), expose(mt), (2, 1, 3), -)
   @test norm(m) == 0
+
+  x = dev(rand(elt, 4, 4))
+  y = dev(rand(elt, 4, 4))
+  copyto!(expose(y), expose(x))
+  @test y == x
+
+  y = dev(rand(elt, 4, 4))
+  x = Base.ReshapedArray(dev(rand(elt, 16)), (4, 4), ())
+  copyto!(expose(y), expose(x))
+  @test NDTensors.cpu(y) == NDTensors.cpu(x)
+  @test NDTensors.cpu(copy(expose(x))) == NDTensors.cpu(x)
+
+  y = dev(rand(elt, 4, 4))
+  x = @view dev(rand(elt, 8, 8))[1:4, 1:4]
+  copyto!(expose(y), expose(x))
+  @test y == x
+  @test copy(x) == x
+
+  y = dev(randn(elt, 16))
+  x = reshape(dev(randn(elt, 4, 4))', 16)
+  copyto!(expose(y), expose(x))
+  @test y == x
+  @test copy(x) == x
+
+  y = dev(randn(elt, 8))
+  x = @view reshape(dev(randn(elt, 8, 8))', 64)[1:8]
+  copyto!(expose(y), expose(x))
+  @test y == x
+  @test copy(x) == x
+
+  y = Base.ReshapedArray(dev(randn(elt, 16)), (4, 4), ())
+  x = dev(randn(elt, 4, 4))
+  permutedims!(expose(y), expose(x), (2, 1))
+  @test NDTensors.cpu(y) == transpose(NDTensors.cpu(x))
 end
```
1 change: 1 addition & 0 deletions NDTensors/src/abstractarray/append.jl
```diff
@@ -1,6 +1,7 @@
 # NDTensors.append!
 # Used to circumvent issues with some GPU backends like Metal
 # not supporting `resize!`.
+# TODO: Change this over to use `expose`.
 function append!!(collection, collections...)
   return append!!(unwrap_type(collection), collection, collections...)
 end
```
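`append!!` follows the "maybe-mutating" bang-bang convention: it returns the appended collection, which may be freshly allocated on backends where `resize!` is unavailable, so callers must rebind the result. A CPU-only sketch, assuming the non-exported `NDTensors.append!!`:

```julia
using NDTensors

v = Float32[1, 2, 3]
w = Float32[4, 5]
v = NDTensors.append!!(v, w)  # rebind: the result may not alias the input
@assert v == Float32[1, 2, 3, 4, 5]
```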
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/linearalgebra.jl
```diff
@@ -181,7 +181,7 @@ function svd(
     if sU == -1
       Ub *= -1
     end
-    copyto!(blockview(U, blockU), Ub)
+    copyto!(expose(blockview(U, blockU)), expose(Ub))
 
     blockviewS = blockview(S, blockS)
     # TODO: Replace `data` with `diagview`.
@@ -200,7 +200,7 @@
     if (sV * sVP) == -1
       Vb *= -1
     end
-    copyto!(blockview(V, blockV), Vb)
+    copyto!(expose(blockview(V, blockV)), expose(Vb))
   end
   return U, S, V, Spectrum(d, truncerr)
 end
```
