From 92bd54dc6b162f6696939958ad06db04e5ea44f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 27 May 2024 17:05:24 +0200 Subject: [PATCH] Fixes for SumOfSquares --- src/algebra_elts.jl | 6 ++++++ src/arithmetic.jl | 20 +++++++++++++++++++- src/bases.jl | 25 +++++++++++++++++++++++++ src/sparse_coeffs.jl | 8 ++++++++ src/types.jl | 7 ++++++- 5 files changed, 64 insertions(+), 2 deletions(-) diff --git a/src/algebra_elts.jl b/src/algebra_elts.jl index 44954df..78ed23c 100644 --- a/src/algebra_elts.jl +++ b/src/algebra_elts.jl @@ -18,6 +18,12 @@ end (a::AlgebraElement)(x) = coeffs(a)[basis(a)[x]] Base.setindex!(a::AlgebraElement, v, idx) = a.coeffs[basis(a)[idx]] = v +function nonzero_pairs(a::AlgebraElement) + return Base.Generator(nonzero_pairs(coeffs(a))) do (k, v) + return basis(a)[k], v + end +end + # AlgebraElement specific functions function supp(a::AlgebraElement) diff --git a/src/arithmetic.jl b/src/arithmetic.jl index b78bf49..f5340dc 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -9,6 +9,10 @@ function _preallocate_output(op, args::Vararg{Any,N}) where {N} return similar(args[1], T) end +function MA.promote_operation(::typeof(similar), ::Type{<:Vector}, ::Type{T}) where {T} + return Vector{T} +end + # module structure: Base.:*(X::AlgebraElement, a::Number) = a * X @@ -25,12 +29,21 @@ function Base.:div(X::AlgebraElement, a::Number) return MA.operate_to!(_preallocate_output(div, X, a), div, X, a) end +function MA.promote_operation( + op::Union{typeof(+),typeof(-)}, + ::Type{AlgebraElement{A,T,VT}}, + ::Type{AlgebraElement{A,S,VS}}, +) where {A,T,VT,S,VS} + U = MA.promote_operation(op, T, S) + return AlgebraElement{A,U,MA.promote_operation(similar, VT, U)} +end function Base.:+(X::AlgebraElement, Y::AlgebraElement) return MA.operate_to!(_preallocate_output(+, X, Y), +, X, Y) end function Base.:-(X::AlgebraElement, Y::AlgebraElement) return MA.operate_to!(_preallocate_output(-, X, Y), -, X, Y) end + function Base.:*(X::AlgebraElement, Y::AlgebraElement) return MA.operate_to!(_preallocate_output(*, X, Y), *, X, Y) end @@ -92,7 +105,7 @@ function MA.operate_to!( X::AlgebraElement, Y::AlgebraElement, ) - @assert parent(res) === parent(X) === parent(Y) + @assert parent(res) == parent(X) == parent(Y) MA.operate_to!(coeffs(res), -, coeffs(X), coeffs(Y)) return res end @@ -138,3 +151,8 @@ function unsafe_push!(a::Vector, k, v) a[k] = MA.add!!(a[k], v) return a end + +function unsafe_push!(a::AlgebraElement, k, v) + unsafe_push!(coeffs(a), basis(a)[k], v) + return a +end diff --git a/src/bases.jl b/src/bases.jl index 813b880..97d9ad5 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -77,3 +77,28 @@ function coeffs!(res, cfs, source::AbstractBasis, target::AbstractBasis) end return res end + +""" + adjoint_coeffs(cfs, source, target) +Return `A' * cfs` where `A` is the linear map applied by +`coeffs`. +""" +function adjoint_coeffs(cfs, source::AbstractBasis, target::AbstractBasis) + source === target && return cfs + source == target && return cfs + res = zero_coeffs(valtype(cfs), source) + return adjoint_coeffs!(res, cfs, source, target) +end + +function adjoint_coeffs!(res, cfs, source::AbstractBasis, target::AbstractBasis) + MA.operate!(zero, res) + for (k, v) in nonzero_pairs(cfs) + x = target[k] + # If `x` is not in `source` then the corresponding row in `A` is zero + # so the column in `A'` is zero hence we can ignore it. + if x in source + res[source[x]] += v + end + end + return res +end diff --git a/src/sparse_coeffs.jl b/src/sparse_coeffs.jl index bd828c9..c9ebeb0 100644 --- a/src/sparse_coeffs.jl +++ b/src/sparse_coeffs.jl @@ -42,6 +42,14 @@ function Base.zero(sc::SparseCoefficients) return SparseCoefficients(empty(keys(sc)), empty(values(sc))) end +function MA.promote_operation( + ::typeof(similar), + ::Type{SparseCoefficients{K,V,Vk,Vv}}, + ::Type{T}, +) where {K,V,Vk,Vv,T} + return SparseCoefficients{K,T,Vk,MA.promote_operation(similar, Vv, T)} +end + function Base.similar(s::SparseCoefficients, ::Type{T} = valtype(s)) where {T} return SparseCoefficients(similar(s.basis_elements), similar(s.values, T)) end diff --git a/src/types.jl b/src/types.jl index d62037d..947b33a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -25,20 +25,25 @@ struct StarAlgebra{O,T,B<:AbstractBasis{T}} <: AbstractStarAlgebra{O,T} end basis(A::StarAlgebra) = A.basis +MA.promote_operation(::typeof(basis), ::Type{StarAlgebra{O,T,B}}) where {O,T,B} = B object(A::StarAlgebra) = A.object +function algebra end + struct AlgebraElement{A,T,V} <: MA.AbstractMutable coeffs::V parent::A end Base.parent(a::AlgebraElement) = a.parent -Base.eltype(a::AlgebraElement) = valtype(coeffs(a)) +Base.eltype(a::AlgebraElement) = eltype(typeof(a)) +Base.eltype(::Type{<:AlgebraElement{A,T}}) where {A,T} = T coeffs(a::AlgebraElement) = a.coeffs function coeffs(x::AlgebraElement, b::AbstractBasis) return coeffs(coeffs(x), basis(x), b) end basis(a::AlgebraElement) = basis(parent(a)) +MA.promote_operation(::typeof(basis), ::Type{<:AlgebraElement{A}}) where {A} = MA.promote_operation(basis, A) function AlgebraElement(coeffs, A::AbstractStarAlgebra) _sanity_checks(coeffs, A)