diff --git a/src/fast_linalg.jl b/src/fast_linalg.jl index 82561d999..168d2afec 100644 --- a/src/fast_linalg.jl +++ b/src/fast_linalg.jl @@ -97,19 +97,19 @@ function fast_X_A_Xt!(out::PSDMatrix, A::PSDMatrix, X::AbstractMatrix) return out end -""" - alloc_free_get_U!(C::Cholesky) - -Allocation-free version of `C.U`. - -THIS MODIFIES `C.factors` SO AFTERWARDS `C` SHOULD NOT BE USED ANYMORE! -""" -function alloc_free_get_U!(C::Cholesky) - Cuplo = getfield(C, :uplo) - Cfactors = getfield(C, :factors) - if Cuplo === LinearAlgebra.char_uplo(:U) - return getupperright!(Cfactors) - else - return getupperright!(Cfactors') - end -end +# """ +# alloc_free_get_U!(C::Cholesky) + +# Allocation-free version of `C.U`. + +# THIS MODIFIES `C.factors` SO AFTERWARDS `C` SHOULD NOT BE USED ANYMORE! +# """ +# function alloc_free_get_U!(C::Cholesky) +# Cuplo = getfield(C, :uplo) +# Cfactors = getfield(C, :factors) +# if Cuplo === LinearAlgebra.char_uplo(:U) +# return getupperright!(Cfactors) +# else +# return getupperright!(Cfactors') +# end +# end diff --git a/src/kronecker.jl b/src/kronecker.jl index 57f8645f4..3c99346db 100644 --- a/src/kronecker.jl +++ b/src/kronecker.jl @@ -27,13 +27,26 @@ IsometricKroneckerProduct(ldim::Integer, B::AbstractVector) = IsometricKroneckerProduct(ldim, reshape(B, :, 1)) const IKP = IsometricKroneckerProduct -get_right_factor(K::IKP) = K.B -get_left_factor_dim(K::IKP) = K.B Kronecker.getmatrices(K::IKP) = (I(K.ldim), K.B) Base.zero(A::IKP) = IsometricKroneckerProduct(A.ldim, zero(A.B)) Base.one(A::IKP) = IsometricKroneckerProduct(A.ldim, one(A.B)) +copy!(A::IKP, B::IKP) = begin + check_same_size(A, B) + copy!(A.B, B.B) + return A +end +copy(A::IKP) = IsometricKroneckerProduct(A.ldim, copy(A.B)) +similar(A::IKP) = IsometricKroneckerProduct(A.ldim, similar(A.B)) +Base.size(K::IKP) = (K.ldim * size(K.B, 1), K.ldim * size(K.B, 2)) + +# conversion +Base.convert(::Type{T}, K::IKP) where {T<:IKP} = + K isa T ? K : T(K) +function IKP{T,TB}(K::IKP) where {T,TB} + IKP(K.ldim, convert(TB, K.B)) +end function Base.:*(A::IKP, B::IKP) @assert A.ldim == B.ldim @@ -55,7 +68,7 @@ function check_matmul_sizes(A::IKP, B::IKP) # For A * B Ad, Bd = A.ldim, B.ldim An, Am, Bn, Bm = size(A)..., size(B)... - if !(A.ldim == B.ldim) || !(Am == Bnb) + if !(A.ldim == B.ldim) || !(Am == Bn) throw( DimensionMismatch( "Matrix multiplication not compatible: A has size ($Ad⋅$An,$Ad⋅$Am), B has size ($Bd⋅$Bn,$Bd⋅$Bm)", @@ -93,6 +106,18 @@ Base.:\(A::IKP, B::IKP) = begin return IsometricKroneckerProduct(A.ldim, A.B \ B.B) end +mul!(A::IKP, B::IKP, C::IKP) = begin + check_matmul_sizes(A, B, C) + mul!(A.B, B.B, C.B) + return A +end +mul!(A::IKP, B::IKP, C::IKP, alpha::Number, beta::Number) = begin + check_matmul_sizes(A, B, C) + mul!(A.B, B.B, C.B, alpha, beta) + return A +end + +# fast_linalg.jl _matmul!(A::IKP, B::IKP, C::IKP) = begin check_matmul_sizes(A, B, C) _matmul!(A.B, B.B, C.B) @@ -105,7 +130,7 @@ _matmul!(A::IKP{T}, B::IKP{T}, C::IKP{T}) where {T<:LinearAlgebra.BlasFloat} = b end _matmul!(A::IKP, B::IKP, C::IKP, alpha::Number, beta::Number) = begin check_matmul_sizes(A, B, C) - _matmul!(A.B, B.B, C.B) + _matmul!(A.B, B.B, C.B, alpha, beta) return A end _matmul!( @@ -119,21 +144,6 @@ _matmul!( _matmul!(A.B, B.B, C.B, alpha, beta) return A end -copy!(A::IKP, B::IKP) = begin - check_same_size(A, B) - copy!(A.B, B.B) - return A -end -copy(A::IKP) = IsometricKroneckerProduct(A.ldim, copy(A.B)) -similar(A::IKP) = IsometricKroneckerProduct(A.ldim, similar(A.B)) -Base.size(K::IKP) = (K.ldim * size(K.B, 1), K.ldim * size(K.B, 2)) - -# conversion -Base.convert(::Type{T}, K::IKP) where {T<:IKP} = - K isa T ? K : T(K) -function IKP{T,TB}(K::IKP) where {T,TB} - IKP(K.ldim, convert(TB, K.B)) -end """ Allocation-free reshape @@ -154,12 +164,7 @@ function mul_vectrick!(x::AbstractVecOrMat, A::IKP, v::AbstractVecOrMat) return x end function mul_vectrick!( - x::AbstractVecOrMat, - A::IKP, - v::AbstractVecOrMat, - alpha::Number, - beta::Number, -) + x::AbstractVecOrMat, A::IKP, v::AbstractVecOrMat, alpha::Number, beta::Number) N = A.B c, d = size(N) @@ -169,39 +174,46 @@ function mul_vectrick!( return x end -_matmul!(C::AbstractVecOrMat, A::IKP, B::AbstractVecOrMat) = mul_vectrick!(C, A, B) mul!(C::AbstractMatrix, A::IKP, B::AbstractMatrix) = mul_vectrick!(C, A, B) mul!(C::AbstractMatrix, A::IKP, B::Adjoint{T,<:AbstractMatrix{T}}) where {T} = mul_vectrick!(C, A, B) mul!(C::AbstractVector, A::IKP, B::AbstractVector) = mul_vectrick!(C, A, B) -_matmul!( - C::AbstractVecOrMat{T}, - A::IKP{T}, - B::AbstractVecOrMat{T}, -) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B) -_matmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::IKP) = _matmul!(C', B', A') -_matmul!( - C::AbstractVecOrMat{T}, - A::AbstractVecOrMat{T}, - B::IKP{T}, -) where {T<:LinearAlgebra.BlasFloat} = _matmul!(C', B', A') - -_matmul!(C::AbstractVecOrMat, A::IKP, B::AbstractVecOrMat, alpha::Number, beta::Number) = - mul_vectrick!(C, A, B, alpha, beta) -_matmul!( - C::AbstractVecOrMat{T}, - A::IKP{T}, - B::AbstractVecOrMat{T}, - alpha::Number, - beta::Number, -) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B, alpha, beta) -_matmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::IKP, alpha::Number, beta::Number) = - mul_vectrick!(C', B', A', alpha, beta) -_matmul!( - C::AbstractVecOrMat{T}, - A::AbstractVecOrMat{T}, - B::IKP{T}, - alpha::Number, - beta::Number, -) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C', B', A', alpha, beta) +for TC in [:AbstractVector, :AbstractMatrix] + @eval mul!(C::$TC, A::IKP, B::$TC) = mul_vectrick!(C, A, B) + @eval mul!(C::$TC, A::IKP, B::Adjoint{T,<:$TC{T}}) where {T} = mul_vectrick!(C, A, B) + @eval mul!(C::$TC, A::IKP, B::$TC, alpha::Number, beta::Number) = + mul_vectrick!(C, A, B, alpha, beta) + + @eval _matmul!(C::$TC, A::IKP, B::$TC) = mul_vectrick!(C, A, B) + @eval _matmul!( + C::$TC{T}, + A::IKP{T}, + B::$TC{T}, + ) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B) + @eval _matmul!(C::$TC, A::$TC, B::IKP) = _matmul!(C', B', A') + @eval _matmul!( + C::$TC{T}, + A::$TC{T}, + B::IKP{T}, + ) where {T<:LinearAlgebra.BlasFloat} = _matmul!(C', B', A') + + @eval _matmul!(C::$TC, A::IKP, B::$TC, alpha::Number, beta::Number) = + mul_vectrick!(C, A, B, alpha, beta) + @eval _matmul!( + C::$TC{T}, + A::IKP{T}, + B::$TC{T}, + alpha::Number, + beta::Number, + ) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B, alpha, beta) + @eval _matmul!(C::$TC, A::$TC, B::IKP, alpha::Number, beta::Number) = + mul_vectrick!(C', B', A', alpha, beta) + @eval _matmul!( + C::$TC{T}, + A::$TC{T}, + B::IKP{T}, + alpha::Number, + beta::Number, + ) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C', B', A', alpha, beta) +end diff --git a/test/runtests.jl b/test/runtests.jl index 9e948bd3d..ca6ec412b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,9 @@ const GROUP = get(ENV, "GROUP", "All") @timedsafetestset "Smoothing" begin include("smoothing.jl") end + @timedsafetestset "IsometricKroneckerProduct" begin + include("core/kronecker.jl") + end end end