Skip to content

Commit

Permalink
Formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Dec 17, 2024
1 parent a939219 commit 611de00
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions test/basics/test_svd.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,51 @@
using Test
using BlockSparseArrays
using BlockSparseArrays:
BlockSparseArray, svd, notrunc, truncbelow, truncdim, BlockDiagonal
using BlockSparseArrays: BlockSparseArray, svd, notrunc, truncbelow, truncdim, BlockDiagonal
using BlockArrays
using LinearAlgebra: LinearAlgebra, Diagonal, svdvals
using Random

function test_svd(a, usv)
U, S, V = usv
U, S, V = usv

@test U * Diagonal(S) * V' a
@test U' * U LinearAlgebra.I
@test V' * V LinearAlgebra.I
@test U * Diagonal(S) * V' a
@test U' * U LinearAlgebra.I
@test V' * V LinearAlgebra.I
end

# regular matrix
# --------------
sizes = ((3, 3), (4, 3), (3, 4))
eltypes = (Float32, Float64, ComplexF64)
@testset "($m, $n) Matrix{$T}" for ((m, n), T) in Iterators.product(sizes, eltypes)
a = rand(m, n)
usv = @inferred svd(a)
test_svd(a, usv)
a = rand(m, n)
usv = @inferred svd(a)
test_svd(a, usv)
end

# block matrix
# ------------
blockszs = (([2, 2], [2, 2]), ([2, 2], [2, 3]), ([2, 2, 1], [2, 3]), ([2, 3], [2]))
@testset "($m, $n) BlockMatrix{$T}" for ((m, n), T) in Iterators.product(blockszs, eltypes)
a = mortar([rand(T, i, j) for i in m, j in n])
usv = svd(a)
test_svd(a, usv)
@test usv.U isa BlockedMatrix
@test usv.Vt isa BlockedMatrix
@test usv.S isa BlockedVector
a = mortar([rand(T, i, j) for i in m, j in n])
usv = svd(a)
test_svd(a, usv)
@test usv.U isa BlockedMatrix
@test usv.Vt isa BlockedMatrix
@test usv.S isa BlockedVector
end

# Block-Diagonal matrices
# -----------------------
@testset "($m, $n) BlockDiagonal{$T}" for ((m, n), T) in
Iterators.product(blockszs, eltypes)
a = BlockDiagonal([rand(T, i, j) for (i, j) in zip(m, n)])
usv = svd(a)
# TODO: `BlockDiagonal * Adjoint` errors
test_svd(a, usv)
@test usv.U isa BlockDiagonal
@test usv.Vt isa BlockDiagonal
@test usv.S isa BlockVector
a = BlockDiagonal([rand(T, i, j) for (i, j) in zip(m, n)])
usv = svd(a)
# TODO: `BlockDiagonal * Adjoint` errors
test_svd(a, usv)
@test usv.U isa BlockDiagonal
@test usv.Vt isa BlockDiagonal
@test usv.S isa BlockVector
end

a = mortar([rand(2, 2) for i in 1:2, j in 1:3])
Expand All @@ -60,24 +60,24 @@ test_svd(a, usv)
# -----------
@testset "($m, $n) BlockDiagonal{$T}" for ((m, n), T) in
Iterators.product(blockszs, eltypes)
a = BlockSparseArray{T}(m, n)
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
end
perm = Random.randperm(length(m))
a = a[Block.(perm), Block.(1:length(n))]
a = BlockSparseArray{T}(m, n)
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
end
perm = Random.randperm(length(m))
a = a[Block.(perm), Block.(1:length(n))]

# errors because `blocks(a)[CartesianIndex.(...)]` is not implemented
usv = svd(a)
# TODO: `BlockDiagonal * Adjoint` errors
test_svd(a, usv)
@test usv.U isa BlockDiagonal
@test usv.Vt isa BlockDiagonal
@test usv.S isa BlockVector
# errors because `blocks(a)[CartesianIndex.(...)]` is not implemented
usv = svd(a)
# TODO: `BlockDiagonal * Adjoint` errors
test_svd(a, usv)
@test usv.U isa BlockDiagonal
@test usv.Vt isa BlockDiagonal
@test usv.S isa BlockVector

test_svd(a, usv2)
@test usv.U isa BlockDiagonal
@test usv.Vt isa BlockDiagonal
@test usv.S isa BlockVector
test_svd(a, usv2)
@test usv.U isa BlockDiagonal
@test usv.Vt isa BlockDiagonal
@test usv.S isa BlockVector
end

0 comments on commit 611de00

Please sign in to comment.