From 61f676674e37048623be6f1aefd6a12de67ffeb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 20 Nov 2024 22:16:54 +0100 Subject: [PATCH] Faster vector substitution --- src/subs.jl | 40 +++++++++++++++++++++++++--------------- test/runtests.jl | 16 ++++++++++++++++ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/src/subs.jl b/src/subs.jl index f856522..7bdac8e 100644 --- a/src/subs.jl +++ b/src/subs.jl @@ -136,8 +136,10 @@ function _add_variables!(p::PolyType, q::PolyType) return p end -function monoeval(z::Vector{Int}, vals::AbstractVector) - @assert length(z) == length(vals) +function _mono_eval(z::Vector{Int}, vals::AbstractVector) + if length(z) != length(vals) + error("") + end if isempty(z) return one(eltype(vals))^1 end @@ -154,24 +156,24 @@ function monoeval(z::Vector{Int}, vals::AbstractVector) return val end -_subs(st, ::Variable, vals) = monoeval([1], vals::AbstractVector) -_subs(st, m::Monomial, vals) = monoeval(m.z, vals::AbstractVector) -function _subs(st, t::_Term, vals) - return MP.coefficient(t) * monoeval(MP.monomial(t).z, vals::AbstractVector) +MP.substitute(::MP.AbstractSubstitutionType, ::Variable, vals::AbstractVector) = _mono_eval((1,), vals) +MP.substitute(::MP.AbstractSubstitutionType, m::Monomial, vals::AbstractVector) = _mono_eval(m.z, vals) +function MP.substitute(st::MP.AbstractSubstitutionType, t::_Term, vals::AbstractVector) + return MP.coefficient(t) * MP.substitute(st, MP.monomial(t), vals) end -function _subs( +function MP.substitute( ::MP.Eval, p::Polynomial{V,M,T}, vals::AbstractVector{S}, ) where {V,M,T,S} # I need to check for iszero otherwise I get : ArgumentError: reducing over an empty collection is not allowed if iszero(p) - zero(Base.promote_op(*, S, T)) + zero(MA.promote_operation(*, S, T)) else - sum(i -> p.a[i] * monoeval(p.x.Z[i], vals), eachindex(p.a)) + sum(i -> p.a[i] * _mono_eval(p.x.Z[i], vals), eachindex(p.a)) end end -function _subs( +function MP.substitute( ::MP.Subs, p::Polynomial{V,M,T}, vals::AbstractVector{S}, @@ -182,7 +184,7 @@ function _subs( mergevars_of(Variable{V,M}, vals)[1], ) for i in eachindex(p.a) - MA.operate!(+, q, p.a[i] * monoeval(p.x.Z[i], vals)) + MA.operate!(+, q, p.a[i] * _mono_eval(p.x.Z[i], vals)) end return q end @@ -197,12 +199,20 @@ function MA.promote_operation( return MA.promote_operation(*, U, Monomial{V,M}) end +function MP.substitute( + st::MP.AbstractSubstitutionType, + p::PolyType, + s::MP.AbstractSubstitution..., +) + return MP.substitute(st, p, subsmap(st, MP.variables(p), s)) +end + function MP.substitute( st::MP.AbstractSubstitutionType, p::PolyType, s::MP.Substitutions, ) - return _subs(st, p, subsmap(st, MP.variables(p), s)) + return MP.substitute(st, p, subsmap(st, MP.variables(p), s)) end (v::Variable)(s::MP.AbstractSubstitution...) = MP.substitute(MP.Eval(), v, s) @@ -215,20 +225,20 @@ function (p::Monomial)(x::NTuple{N,<:Number}) where {N} return MP.substitute(MP.Eval(), p, variables(p) => x) end function (p::Monomial)(x::AbstractVector{<:Number}) - return MP.substitute(MP.Eval(), p, variables(p) => x) + return MP.substitute(MP.Eval(), p, x) end (p::Monomial)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x) function (p::_Term)(x::NTuple{N,<:Number}) where {N} return MP.substitute(MP.Eval(), p, variables(p) => x) end function (p::_Term)(x::AbstractVector{<:Number}) - return MP.substitute(MP.Eval(), p, variables(p) => x) + return MP.substitute(MP.Eval(), p, x) end (p::_Term)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x) function (p::Polynomial)(x::NTuple{N,<:Number}) where {N} return MP.substitute(MP.Eval(), p, variables(p) => x) end function (p::Polynomial)(x::AbstractVector{<:Number}) - return MP.substitute(MP.Eval(), p, variables(p) => x) + return MP.substitute(MP.Eval(), p, x) end (p::Polynomial)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x) diff --git a/test/runtests.jl b/test/runtests.jl index 7226e2d..c9f1a49 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,23 @@ using MultivariatePolynomials using Test using LinearAlgebra +function alloc_test_lt(f, n) + f() # compile + @test n >= @allocated f() +end + # TODO move to MP +@testset "See https://github.com/jump-dev/SumOfSquares.jl/issues/388" begin + @polyvar x[1:3] + p = sum(x) + v = map(_ -> 1, x) + # I get 208 but let's give some margin + alloc_test_lt(() -> substitute(Eval(), p, x => v), 300) + alloc_test_lt(() -> p(x => v), 300) + alloc_test_lt(() -> substitute(Eval(), p, v), 0) + alloc_test_lt(() -> p(v), 0) +end + @testset "Issue #70" begin @ncpolyvar y0 y1 x0 x1 p = x1 * x0 * x1