diff --git a/src/bases.jl b/src/bases.jl index e9d738f..2cd8fb0 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -14,6 +14,12 @@ Base.eltype(b::AbstractBasis) = eltype(typeof(b)) Base.keytype(::Type{<:AbstractBasis{T,I}}) where {T,I} = I Base.keytype(b::AbstractBasis) = keytype(typeof(b)) +key_type(b) = keytype(b) +# `keytype(::Type{SparseVector{V,K}})` is not defined so it falls +# back to `keytype{::Type{<:AbstractArray})` which returns `Int`. +key_type(::Type{SparseArrays.SparseVector{V,K}}) where {V,K} = K +key_type(v::SparseArrays.SparseVector) = key_type(typeof(v)) + """ ImplicitBasis{T,I} Implicit bases are not stored in memory and can be potentially infinite. diff --git a/src/mstructures.jl b/src/mstructures.jl index fdd155a..3d59e25 100644 --- a/src/mstructures.jl +++ b/src/mstructures.jl @@ -41,16 +41,13 @@ struct UnsafeAddMul{M<:Union{typeof(*),MultiplicativeStructure}} structure::M end -function MA.operate_to!( - res::SparseCoefficients, - ms::MultiplicativeStructure, - v::AbstractCoefficients, - w::AbstractCoefficients, -) +function MA.operate_to!(res, ms::MultiplicativeStructure, v, w) + if res === v || res === w + throw(ArgumentError("No alias allowed")) + end MA.operate!(zero, res) res = MA.operate!(UnsafeAddMul(ms), res, v, w) - __canonicalize!(res) - return res + return __canonicalize!(res) end function MA.operate!( @@ -66,36 +63,16 @@ function MA.operate!( return mc end -function MA.operate!( - ms::UnsafeAddMul, - res::SparseCoefficients, - v::AbstractCoefficients, - w::AbstractCoefficients, -) +function MA.operate!(ms::UnsafeAddMul, res, v, w) for (kv, a) in pairs(v) for (kw, b) in pairs(w) - c = ms.structure(kv, kw) # ::AbstractCoefficients + c = ms.structure(kv, kw) MA.operate!(UnsafeAddMul(*), res, a * b, c) end end return res end -function MA.operate_to!( - res::AbstractVector, - ms::MultiplicativeStructure, - X::AbstractVector, - Y::AbstractVector, -) - if res === X || res === Y - throw(ArgumentError("No alias allowed")) - end - MA.operate!(zero, res) - MA.operate!(UnsafeAddMul(ms), res, X, Y) - res = __canonicalize!(res) - return res -end - __canonicalize!(sv::SparseVector) = dropzeros!(sv) __canonicalize!(v::AbstractVector) = v struct DiracMStructure{Op} <: MultiplicativeStructure diff --git a/src/types.jl b/src/types.jl index a74ae90..0a599e6 100644 --- a/src/types.jl +++ b/src/types.jl @@ -43,9 +43,17 @@ struct AlgebraElement{A,T,V} <: MA.AbstractMutable parent::A end -function AlgebraElement(coeffs::AbstractVector, A::AbstractStarAlgebra) +function _sanity_checks(coeffs, A::AbstractStarAlgebra) + @assert key_type(coeffs) == keytype(basis(A)) +end +function _sanity_checks(coeffs::AbstractVector, A::AbstractStarAlgebra) + @assert key_type(coeffs) == keytype(basis(A)) @assert Base.haslength(basis(A)) @assert length(coeffs) == length(basis(A)) +end + +function AlgebraElement(coeffs, A::AbstractStarAlgebra) + _sanity_checks(coeffs, A) return AlgebraElement{typeof(A),valtype(coeffs),typeof(coeffs)}(coeffs, A) end