Skip to content

Commit

Permalink
Add more IsometricKroneckerProduct tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Oct 28, 2023
1 parent d0c7410 commit 53f55e7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ for TC in [:AbstractVector, :AbstractMatrix]
A::IKP{T},
B::$TC{T},
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B)
@eval _matmul!(C::$TC, A::$TC, B::IKP) = _matmul!(C', B', A')
@eval _matmul!(C::$TC, A::$TC, B::IKP) = _matmul!(C', B', A')'
@eval _matmul!(
C::$TC{T},
A::$TC{T},
B::IKP{T},
) where {T<:LinearAlgebra.BlasFloat} = _matmul!(C', B', A')
) where {T<:LinearAlgebra.BlasFloat} = _matmul!(C', B', A')'

@eval _matmul!(C::$TC, A::IKP, B::$TC, alpha::Number, beta::Number) =
mul_vectrick!(C, A, B, alpha, beta)
Expand All @@ -207,12 +207,12 @@ for TC in [:AbstractVector, :AbstractMatrix]
beta::Number,
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B, alpha, beta)
@eval _matmul!(C::$TC, A::$TC, B::IKP, alpha::Number, beta::Number) =
mul_vectrick!(C', B', A', alpha, beta)
mul_vectrick!(C', B', A', alpha, beta)'
@eval _matmul!(
C::$TC{T},
A::$TC{T},
B::IKP{T},
alpha::Number,
beta::Number,
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C', B', A', alpha, beta)
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C', B', A', alpha, beta)'
end
25 changes: 25 additions & 0 deletions test/core/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ q = 2
@test PNDE._matmul!(K3, K1, K2) PNDE._matmul!(M3, M1, M2)
@test PNDE._matmul!(K3, K1, K2, α, β) PNDE._matmul!(M3, M1, M2, α, β)

# DimensionMismatch
@test_throws DimensionMismatch mul!(X, K1, K2)
@test_throws DimensionMismatch mul!(X, K1, K2, α, β)
@test_throws DimensionMismatch PNDE._matmul!(X, K1, K2)
@test_throws DimensionMismatch PNDE._matmul!(X, K1, K2, α, β)

# Kronecker-trick
v = rand(T, d * (q + 1))
A = rand(T, d * (q + 1), d * (q + 1))
Expand All @@ -96,8 +102,27 @@ q = 2
@test mul!(copy(v), K1, v, α, β) mul!(copy(v), M1, v, α, β)
@test mul!(copy(A), K1, A, α, β) mul!(copy(A), M1, A, α, β)

@test mul!(copy(A'), copy(A'), K1') mul!(copy(A'), A', M1')
@test mul!(copy(A'), copy(A'), K1', α, β) mul!(copy(A'), A', M1', α, β)

@test PNDE._matmul!(copy(v), K1, v) PNDE._matmul!(copy(v), M1, v)
@test PNDE._matmul!(copy(A), K1, A) PNDE._matmul!(copy(A), M1, A)
@test PNDE._matmul!(copy(v), K1, v, α, β) PNDE._matmul!(copy(v), M1, v, α, β)
@test PNDE._matmul!(copy(A), K1, A, α, β) PNDE._matmul!(copy(A), M1, A, α, β)

if T == Float64
# Octavian has issues
@test_broken PNDE._matmul!(copy(A'), copy(A'), K1')
PNDE._matmul!(copy(A'), A', M1')
@test_broken PNDE._matmul!(copy(A'), copy(A'), K1', α, β)
PNDE._matmul!(copy(A'), A', M1', α, β)
else
# Uses LinearAlgebra
@test PNDE._matmul!(copy(A'), copy(A'), K1') PNDE._matmul!(copy(A'), A', M1')
@test PNDE._matmul!(copy(A'), copy(A'), K1', α, β)
PNDE._matmul!(copy(A'), A', M1', α, β)
end
# But it always works if all matrices are actual adjoints
@test PNDE._matmul!(copy(A)', A', K1') PNDE._matmul!(copy(A'), A', M1')
@test PNDE._matmul!(copy(A)', A', K1', α, β) PNDE._matmul!(copy(A'), A', M1', α, β)
end

0 comments on commit 53f55e7

Please sign in to comment.