Skip to content

Commit

Permalink
Add unit-tests for our BlockDiags
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 15, 2024
1 parent aafb416 commit 6501878
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 19 deletions.
2 changes: 1 addition & 1 deletion ext/BlockDiagonalsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module BlockDiagonalsExt

import ProbNumDiffEq: ProbNumDiffEqBlockDiagonal, blocks
using BlockDiagonals
import BlockDiagonals: BlockDiagonal

BlockDiagonal(M::ProbNumDiffEqBlockDiagonal) = BlockDiagonal(blocks(M))

Expand Down
2 changes: 1 addition & 1 deletion src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ __precompile__()

module ProbNumDiffEq

import Base: copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length
import Base: copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length, zero

using LinearAlgebra
import LinearAlgebra: mul!
Expand Down
44 changes: 27 additions & 17 deletions src/blockdiagonals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ copy!(B::BlockDiag, A::BlockDiag) = begin
end
return B
end
similar(B::BlockDiag) = BlockDiag(similar.(blocks(B)))
zero(B::BlockDiag) = BlockDiag(zero.(blocks(B)))

# Mul with Scalar or UniformScaling
Base.:*(a::Number, M::BlockDiag) = BlockDiag([a * B for B in blocks(M)])
Base.:*(M::BlockDiag, a::Number) = BlockDiag([B * a for B in blocks(M)])
Base.:*(U::UniformScaling, M::BlockDiag) = BlockDiag([U * B for B in blocks(M)])
Base.:*(M::BlockDiag, U::UniformScaling) = BlockDiag([B * U for B in blocks(M)])

# Mul between BockDiag's
Base.:*(A::BlockDiag, B::BlockDiag) = begin
@assert length(A.blocks) == length(B.blocks)
return BlockDiag([Ai * Bi for (Ai, Bi) in zip(blocks(A), blocks(B))])
end
Base.:*(A::Adjoint{T,<:BlockDiag}, B::BlockDiag) where {T} = begin
@assert length(A.parent.blocks) == length(B.blocks)
return BlockDiag([Ai' * Bi for (Ai, Bi) in zip(blocks(A.parent), blocks(B))])
end
Base.:*(A::BlockDiag, B::Adjoint{T,<:BlockDiag}) where {T} = begin
@assert length(A.blocks) == length(B.parent.blocks)
return BlockDiag([Ai * Bi' for (Ai, Bi) in zip(blocks(A), blocks(B.parent))])
end

# Standard LinearAlgebra.mul!
mul!(C::BlockDiag, A::BlockDiag, B::BlockDiag) = begin
Expand Down Expand Up @@ -156,21 +178,9 @@ _matmul!(
return C
end

LinearAlgebra.rmul!(B::BlockDiag, n::Number) = @simd ivdep for i in eachindex(B.blocks)
rmul!(B.blocks[i], n)
end
LinearAlgebra.adjoint(B::BlockDiag) = Adjoint(B)

Base.:*(A::BlockDiag, B::BlockDiag) = begin
@assert length(A.blocks) == length(B.blocks)
return BlockDiag([blocks(A)[i] * blocks(B)[i] for i in eachindex(B.blocks)])
end
Base.:*(A::Adjoint{T,<:BlockDiag}, B::BlockDiag) where {T} = begin
@assert length(A.parent.blocks) == length(B.blocks)
return BlockDiag([A.parent.blocks[i]' * B.blocks[i] for i in eachindex(B.blocks)])
end
Base.:*(A::BlockDiag, B::Adjoint{T,<:BlockDiag}) where {T} = begin
@assert length(A.blocks) == length(B.parent.blocks)
return BlockDiag([A.blocks[i] * B.parent.blocks[i]' for i in eachindex(B.parent.blocks)])
LinearAlgebra.rmul!(B::BlockDiag, n::Number) = begin
@simd ivdep for i in eachindex(B.blocks)
rmul!(B.blocks[i], n)
end
return B
end
Base.:*(A::UniformScaling, B::BlockDiag) = BlockDiag([A * blocks(B)[i] for i in eachindex(B.blocks)])
57 changes: 57 additions & 0 deletions test/core/blockdiagonals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using ProbNumDiffEq
import ProbNumDiffEq: BlockDiag, _matmul!
using LinearAlgebra
using BlockDiagonals
using Test

d1, d2 = 2, 3
@testset "T=$T" for T in (Float64, BigFloat)
A = BlockDiag([randn(T, d1, d1) for _ in 1:d2])
B = BlockDiag([randn(T, d1, d1) for _ in 1:d2])
C = BlockDiag([randn(T, d1, d1) for _ in 1:d2])

AM, BM, CM = @test_nowarn Matrix.((A, B, C))

@test Matrix(BlockDiagonal(A)) == AM
@test Matrix(BlockDiagonal(B)) == BM
@test Matrix(BlockDiagonal(C)) == CM

_A = @test_nowarn copy(A)
@test _A isa BlockDiag

_B = @test_nowarn copy!(_A, B)
@test _B === _A
@test _B == B

_A = @test_nowarn similar(A)
@test _A isa BlockDiag
@test size(_A) == size(A)

_Z = @test_nowarn zero(A)
@test _Z isa BlockDiag
@test size(_Z) == size(A)
@test all(_Z .== 0)

function tttm(M) # quick type test and to matrix
@test M isa BlockDiag
return Matrix(M)
end

for _mul! in (:mul!, :_matmul!)
@test @eval tttm($_mul!(C, A, B)) $_mul!(CM, AM, BM)
@test @eval tttm($_mul!(C, A', B)) $_mul!(CM, AM', BM)
@test @eval tttm($_mul!(C, A, B')) $_mul!(CM, AM, BM')
end
@test tttm(A * B) AM * BM
@test tttm(A' * B) AM' * BM
@test tttm(A * B') AM * BM'

a = rand()
@test tttm(A * a) AM * a
@test tttm(a * A) a * AM
@test tttm(A * (a * I)) AM * a
@test tttm((a * I) * A) a * AM
@test tttm(rmul!(copy(A), a)) a * AM

@test_throws ErrorException view(A, 1:2, 1:2)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ const GROUP = get(ENV, "GROUP", "All")
@testset "ProbNumDiffEq" begin
if GROUP == "All" || GROUP == "Core"
@timedtestset "Core" begin
@timedsafetestset "BlockDiagonals" begin
include("core/blockdiagonals.jl")
end
@timedsafetestset "Filtering" begin
include("core/filtering.jl")
end
Expand Down

0 comments on commit 6501878

Please sign in to comment.