Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 23, 2023
1 parent c133f40 commit d60b852
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 9 deletions.
4 changes: 3 additions & 1 deletion NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ end

function NDTensors.svd(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
U, S, V = NDTensors.svd(NDTensors.cpu(A))
return adapt(leaf_parenttype, U), adapt(set_ndims(leaf_parenttype, ndims(S)), S), adapt(leaf_parenttype, V)
return adapt(leaf_parenttype, U),
adapt(set_ndims(leaf_parenttype, ndims(S)), S),
adapt(leaf_parenttype, V)
end
8 changes: 7 additions & 1 deletion NDTensors/ext/NDTensorsMetalExt/permutedims.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
function NDTensors.permutedims!(::Type{<:MtlArray}, Adest::Base.ReshapedArray{<:Any,<:Any,<:SubArray}, ::Type{<:MtlArray}, A, perm)
function NDTensors.permutedims!(
::Type{<:MtlArray},
Adest::Base.ReshapedArray{<:Any,<:Any,<:SubArray},
::Type{<:MtlArray},
A,
perm,
)
Aperm = permutedims(A, perm)
Adest_parent = parent(Adest)
copyto!(Adest_parent, Aperm)
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/abstractarray/permutedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ function permutedims!!(B::AbstractArray, A::AbstractArray, perm)
return permutedims!!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm)
end

function permutedims!!(Bleaftype::Type{<:AbstractArray}, B, Aleaftype::Type{<:AbstractArray}, A, perm)
function permutedims!!(
Bleaftype::Type{<:AbstractArray}, B, Aleaftype::Type{<:AbstractArray}, A, perm
)
permutedims!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm)
return B
end
Expand Down
6 changes: 5 additions & 1 deletion NDTensors/src/blocksparse/blocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ function BlockSparse(
end

function BlockSparse(
datatype::Type{<:AbstractArray}, ::UndefInitializer, blockoffsets::BlockOffsets, dim::Integer; vargs...
datatype::Type{<:AbstractArray},
::UndefInitializer,
blockoffsets::BlockOffsets,
dim::Integer;
vargs...,
)
return BlockSparse(datatype(undef, dim), blockoffsets; vargs...)
end
Expand Down
5 changes: 4 additions & 1 deletion NDTensors/src/blocksparse/blocksparsetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ function BlockSparseTensor(
end

function BlockSparseTensor(
datatype::Type{<:AbstractArray}, ::UndefInitializer, blocks::Vector{<:Union{Block,NTuple}}, inds
datatype::Type{<:AbstractArray},
::UndefInitializer,
blocks::Vector{<:Union{Block,NTuple}},
inds,
)
boffs, nnz = blockoffsets(blocks, inds)
storage = BlockSparse(datatype, undef, boffs, nnz)
Expand Down
5 changes: 4 additions & 1 deletion NDTensors/src/blocksparse/diagblocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ function DiagBlockSparse(
end

function DiagBlockSparse(
datatype::Type{<:AbstractArray}, ::UndefInitializer, boffs::BlockOffsets, diaglength::Integer
datatype::Type{<:AbstractArray},
::UndefInitializer,
boffs::BlockOffsets,
diaglength::Integer,
)
return DiagBlockSparse(datatype(undef, diaglength), boffs)
end
Expand Down
8 changes: 6 additions & 2 deletions NDTensors/src/blocksparse/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}
indsS = setindex(indsS, dag(vind), 2)

U = BlockSparseTensor(leaf_parenttype(T), undef, nzblocksU, indsU)
S = DiagBlockSparseTensor(set_eltype(leaf_parenttype(T), real(ElT)), undef, nzblocksS, indsS)
S = DiagBlockSparseTensor(
set_eltype(leaf_parenttype(T), real(ElT)), undef, nzblocksS, indsS
)
V = BlockSparseTensor(leaf_parenttype(T), undef, nzblocksV, indsV)

for n in 1:nnzblocksT
Expand Down Expand Up @@ -302,7 +304,9 @@ function eigen(
nzblocksV[n] = blockV
end

D = DiagBlockSparseTensor(set_ndims(set_eltype(leaf_parenttype(T), ElD), 1), undef, nzblocksD, indsD)
D = DiagBlockSparseTensor(
set_ndims(set_eltype(leaf_parenttype(T), ElD), 1), undef, nzblocksD, indsD
)
V = BlockSparseTensor(set_eltype(leaf_parenttype(T), ElV), undef, nzblocksV, indsV)

for n in 1:nnzblocksT
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/dense/densetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ end
return T
end

@propagate_inbounds function setindex!(::Type{<:AbstractArray}, T::DenseTensor{<:Number}, x::Number)
@propagate_inbounds function setindex!(
::Type{<:AbstractArray}, T::DenseTensor{<:Number}, x::Number
)
data(T)[] = x
return T
end
Expand Down

0 comments on commit d60b852

Please sign in to comment.