From 1046d55a4a8bfdea0efc362bbf64f3f3c7de7813 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 3 Jul 2024 09:31:25 +0200 Subject: [PATCH] Implement promote_operation (#37) * Implement promote_operation * Fixes * Add tests * Fix format --- src/MultivariateBases.jl | 6 +++ src/arithmetic.jl | 92 +++++++++++++++++++++++++++++++--------- src/monomial.jl | 52 ++++++++++++++++++++--- test/hermite.jl | 8 ++++ test/runtests.jl | 29 +++++++++---- 5 files changed, 153 insertions(+), 34 deletions(-) diff --git a/src/MultivariateBases.jl b/src/MultivariateBases.jl index 74590f4..c03ac5a 100644 --- a/src/MultivariateBases.jl +++ b/src/MultivariateBases.jl @@ -22,6 +22,12 @@ MP.monomial_type(::Type{<:Algebra{B}}) where {B} = MP.monomial_type(B) function MP.polynomial_type(::Type{<:Algebra{B}}, ::Type{T}) where {B,T} return MP.polynomial_type(B, T) end +function MA.promote_operation( + ::typeof(SA.basis), + ::Type{<:Algebra{B}}, +) where {B} + return B +end SA.basis(a::Algebra) = a.basis #Base.:(==)(::Algebra{BT1,B1,M}, ::Algebra{BT2,B2,M}) where {BT1,B1,BT2,B2,M} = true diff --git a/src/arithmetic.jl b/src/arithmetic.jl index 8c1c5ea..e613c3d 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -2,26 +2,76 @@ const _APL = MP.AbstractPolynomialLike # We don't define it for all `AlgebraElement` as this would be type piracy const _AE = SA.AlgebraElement{<:Algebra} -Base.:(+)(p::_APL, q::_AE) = +(p, MP.polynomial(q)) -Base.:(+)(p::_AE, q::_APL) = +(MP.polynomial(p), q) -Base.:(-)(p::_APL, q::_AE) = -(p, MP.polynomial(q)) -Base.:(-)(p::_AE, q::_APL) = -(MP.polynomial(p), q) - -Base.:(+)(p, q::_AE) = +(constant_algebra_element(typeof(SA.basis(q)), p), q) -function Base.:(+)(p::_AE, q) - return +(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q)) -end -function Base.:(-)(p, q::_AE) - return -(constant_algebra_element(typeof(SA.basis(q)), p), MP.polynomial(q)) -end -function Base.:(-)(p::_AE, q) - return -(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q)) +for op in [:+, :-, :*] + @eval begin + function MA.promote_operation( + ::typeof($op), + ::Type{P}, + ::Type{Q}, + ) where {P<:_APL,Q<:_AE} + return MA.promote_operation($op, P, MP.polynomial_type(Q)) + end + Base.$op(p::_APL, q::_AE) = $op(p, MP.polynomial(q)) + function MA.promote_operation( + ::typeof($op), + ::Type{P}, + ::Type{Q}, + ) where {P<:_AE,Q<:_APL} + return MA.promote_operation($op, MP.polynomial_type(P), Q) + end + Base.$op(p::_AE, q::_APL) = $op(MP.polynomial(p), q) + # Break ambiguity between the two defined below and the generic one in SA + function MA.promote_operation( + ::typeof($op), + ::Type{P}, + ::Type{Q}, + ) where {P<:_AE,Q<:_AE} + return SA.algebra_promote_operation($op, P, Q) + end + function Base.$op(p::_AE, q::_AE) + return MA.operate_to!(SA._preallocate_output($op, p, q), $op, p, q) + end + end end - -function Base.:(+)(p::_AE, q::_AE) - return MA.operate_to!(SA._preallocate_output(+, p, q), +, p, q) -end - -function Base.:(-)(p::_AE, q::_AE) - return MA.operate_to!(SA._preallocate_output(-, p, q), -, p, q) +for op in [:+, :-] + @eval begin + function MA.promote_operation( + ::typeof($op), + ::Type{P}, + ::Type{Q}, + ) where {P,Q<:_AE} + I = MA.promote_operation(implicit, Q) + return MA.promote_operation( + $op, + constant_algebra_element_type( + MA.promote_operation(SA.basis, I), + P, + ), + I, + ) + end + function Base.$op(p, q::_AE) + i = implicit(q) + return $op(constant_algebra_element(typeof(SA.basis(i)), p), i) + end + function MA.promote_operation( + ::typeof($op), + ::Type{P}, + ::Type{Q}, + ) where {P<:_AE,Q} + I = MA.promote_operation(implicit, P) + return MA.promote_operation( + $op, + I, + constant_algebra_element_type( + MA.promote_operation(SA.basis, I), + Q, + ), + ) + end + function Base.$op(p::_AE, q) + i = implicit(p) + return $op(i, constant_algebra_element(typeof(SA.basis(i)), q)) + end + end end diff --git a/src/monomial.jl b/src/monomial.jl index 44d6f2d..8c99ad0 100644 --- a/src/monomial.jl +++ b/src/monomial.jl @@ -127,6 +127,21 @@ end implicit_basis(::SubBasis{B,M}) where {B,M} = FullBasis{B,M}() implicit_basis(basis::FullBasis) = basis +function implicit(a::SA.AlgebraElement) + basis = implicit_basis(SA.basis(a)) + return algebra_element(SA.coeffs(a, basis), basis) +end + +function MA.promote_operation( + ::typeof(implicit), + ::Type{E}, +) where {AG,T,E<:SA.AlgebraElement{AG,T}} + BT = MA.promote_operation(implicit_basis, MA.promote_operation(SA.basis, E)) + A = MA.promote_operation(algebra, BT) + M = MP.monomial_type(BT) + return SA.AlgebraElement{A,T,SA.SparseCoefficients{M,T,Vector{M},Vector{T}}} +end + function MA.promote_operation( ::typeof(implicit_basis), ::Type{<:Union{FullBasis{B,M},SubBasis{B,M}}}, @@ -201,6 +216,14 @@ end _one_if_type(α) = α _one_if_type(::Type{T}) where {T} = one(T) +function constant_algebra_element_type( + ::Type{BT}, + ::Type{T}, +) where {B,M,BT<:FullBasis{B,M},T} + A = MA.promote_operation(algebra, BT) + return SA.AlgebraElement{A,T,SA.SparseCoefficients{M,T,Vector{M},Vector{T}}} +end + function constant_algebra_element(::Type{FullBasis{B,M}}, α) where {B,M} return algebra_element( sparse_coefficients( @@ -210,6 +233,14 @@ function constant_algebra_element(::Type{FullBasis{B,M}}, α) where {B,M} ) end +function constant_algebra_element_type( + ::Type{B}, + ::Type{T}, +) where {B<:SubBasis,T} + A = MA.promote_operation(algebra, B) + return SA.AlgebraElement{A,T,Vector{T}} +end + function constant_algebra_element(::Type{<:SubBasis{B,M}}, α) where {B,M} return algebra_element( [_one_if_type(α)], @@ -377,14 +408,25 @@ function MP.polynomial_type(::Type{FullBasis{B,M}}, ::Type{T}) where {T,B,M} return MP.polynomial_type(M, _promote_coef(T, B)) end +_vec(v::Vector) = v +_vec(v::AbstractVector) = collect(v) + # Adapted from SA to incorporate `_promote_coef` function SA.coeffs( cfs, - source::MonomialIndexedBasis{B}, - target::MonomialIndexedBasis{Monomial}, -) where {B} + source::MonomialIndexedBasis{B1}, + target::MonomialIndexedBasis{B2}, +) where {B1,B2} source === target && return cfs source == target && return cfs - res = SA.zero_coeffs(_promote_coef(valtype(cfs), B), target) - return SA.coeffs!(res, cfs, source, target) + if B1 === B2 && target isa FullBasis + # The defaults initialize to zero and then sums which promotes + # `JuMP.VariableRef` to `JuMP.AffExpr` + return SA.SparseCoefficients(_vec(source.monomials), _vec(cfs)) + elseif B2 === Monomial + res = SA.zero_coeffs(_promote_coef(valtype(cfs), B1), target) + return SA.coeffs!(res, cfs, source, target) + else + error("Convertion from `$source` to `$target` not implemented yet") + end end diff --git a/test/hermite.jl b/test/hermite.jl index e0fa720..c72c97e 100644 --- a/test/hermite.jl +++ b/test/hermite.jl @@ -26,6 +26,7 @@ end end @testset "Coefficients" begin + @polyvar x coefficient_test( MB.ProbabilistsHermite, [4, 6, 6, 1, 9, 1, 1, 1]; @@ -44,4 +45,11 @@ end 1.0, ]), ) + M = typeof(x^2) + mono = MB.FullBasis{MB.Monomial,M}() + basis = MB.FullBasis{MB.PhysicistsHermite,M}() + err = ErrorException( + "Convertion from `$mono` to `$basis` not implemented yet", + ) + @test_throws err SA.coeffs(MB.algebra_element(x + 1), basis) end diff --git a/test/runtests.jl b/test/runtests.jl index 6fd543e..569f942 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,14 +7,26 @@ const MB = MultivariateBases using LinearAlgebra using DynamicPolynomials +function _test_op(op, args...) + result = @inferred op(args...) + @test typeof(result) == MA.promote_operation(op, typeof.(args)...) + return result +end + +function _test_basis(basis) + B = typeof(basis) + @test typeof(MB.algebra(basis)) == MA.promote_operation(MB.algebra, B) + @test typeof(MB.constant_algebra_element(B, 1)) == + MB.constant_algebra_element_type(B, Int) +end + function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree) @polyvar x[1:2] M = typeof(prod(x)) full_basis = FullBasis{B,M}() + _test_basis(full_basis) @test sprint(show, MB.algebra(full_basis)) == "Polynomial algebra of $B basis" - @test typeof(MB.algebra(full_basis)) == - MA.promote_operation(MB.algebra, typeof(full_basis)) for basis in [ maxdegree_basis(full_basis, x, degree), explicit_basis_covering( @@ -26,8 +38,7 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree) MB.SubBasis{ScaledMonomial}(monomials(x, 0:degree)), ), ] - @test typeof(MB.algebra(basis)) == - MA.promote_operation(MB.algebra, typeof(basis)) + _test_basis(basis) @test basis isa MB.explicit_basis_type(typeof(full_basis)) for i in eachindex(basis) mono = basis.monomials[i] @@ -47,6 +58,8 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree) @test length(empty_basis(typeof(basis))) == 0 @test polynomial_type(basis, Float64) == polynomial_type(x[1], Float64) #@test polynomial(i -> 0.0, basis) isa polynomial_type(basis, Float64) + a = MB.algebra_element(ones(length(basis)), basis) + _test_op(MB.implicit, a) end mono = x[1]^2 * x[2]^3 p = MB.Polynomial{B}(mono) @@ -77,10 +90,10 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree) const_poly = MB.Polynomial{B}(const_mono) const_alg_el = MB.algebra_element(const_poly) for other in (const_mono, 1, const_alg_el) - @test other + const_alg_el ≈ 2 * other - @test const_alg_el + other ≈ 2 * other - @test iszero(other - const_alg_el) - @test iszero(const_alg_el - other) + @test _test_op(+, other, const_alg_el) ≈ _test_op(*, 2, other) + @test _test_op(+, const_alg_el, other) ≈ _test_op(*, 2, other) + @test iszero(_test_op(-, other, const_alg_el)) + @test iszero(_test_op(-, const_alg_el, other)) end @test typeof(MB.sparse_coefficients(sum(x))) == MA.promote_operation(MB.sparse_coefficients, typeof(sum(x)))