Skip to content

Commit

Permalink
Faster vector substitution
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Nov 20, 2024
1 parent e5df766 commit 61f6766
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
40 changes: 25 additions & 15 deletions src/subs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ function _add_variables!(p::PolyType, q::PolyType)
return p
end

function monoeval(z::Vector{Int}, vals::AbstractVector)
@assert length(z) == length(vals)
function _mono_eval(z::Vector{Int}, vals::AbstractVector)
if length(z) != length(vals)
error("")
end
if isempty(z)
return one(eltype(vals))^1
end
Expand All @@ -154,24 +156,24 @@ function monoeval(z::Vector{Int}, vals::AbstractVector)
return val
end

_subs(st, ::Variable, vals) = monoeval([1], vals::AbstractVector)
_subs(st, m::Monomial, vals) = monoeval(m.z, vals::AbstractVector)
function _subs(st, t::_Term, vals)
return MP.coefficient(t) * monoeval(MP.monomial(t).z, vals::AbstractVector)
MP.substitute(::MP.AbstractSubstitutionType, ::Variable, vals::AbstractVector) = _mono_eval((1,), vals)
MP.substitute(::MP.AbstractSubstitutionType, m::Monomial, vals::AbstractVector) = _mono_eval(m.z, vals)
function MP.substitute(st::MP.AbstractSubstitutionType, t::_Term, vals::AbstractVector)
return MP.coefficient(t) * MP.substitute(st, MP.monomial(t), vals)
end
function _subs(
function MP.substitute(
::MP.Eval,
p::Polynomial{V,M,T},
vals::AbstractVector{S},
) where {V,M,T,S}
# I need to check for iszero otherwise I get : ArgumentError: reducing over an empty collection is not allowed
if iszero(p)
zero(Base.promote_op(*, S, T))
zero(MA.promote_operation(*, S, T))
else
sum(i -> p.a[i] * monoeval(p.x.Z[i], vals), eachindex(p.a))
sum(i -> p.a[i] * _mono_eval(p.x.Z[i], vals), eachindex(p.a))
end
end
function _subs(
function MP.substitute(
::MP.Subs,
p::Polynomial{V,M,T},
vals::AbstractVector{S},
Expand All @@ -182,7 +184,7 @@ function _subs(
mergevars_of(Variable{V,M}, vals)[1],
)
for i in eachindex(p.a)
MA.operate!(+, q, p.a[i] * monoeval(p.x.Z[i], vals))
MA.operate!(+, q, p.a[i] * _mono_eval(p.x.Z[i], vals))
end
return q
end
Expand All @@ -197,12 +199,20 @@ function MA.promote_operation(
return MA.promote_operation(*, U, Monomial{V,M})
end

function MP.substitute(
st::MP.AbstractSubstitutionType,
p::PolyType,
s::MP.AbstractSubstitution...,
)
return MP.substitute(st, p, subsmap(st, MP.variables(p), s))
end

function MP.substitute(
st::MP.AbstractSubstitutionType,
p::PolyType,
s::MP.Substitutions,
)
return _subs(st, p, subsmap(st, MP.variables(p), s))
return MP.substitute(st, p, subsmap(st, MP.variables(p), s))
end

(v::Variable)(s::MP.AbstractSubstitution...) = MP.substitute(MP.Eval(), v, s)
Expand All @@ -215,20 +225,20 @@ function (p::Monomial)(x::NTuple{N,<:Number}) where {N}
return MP.substitute(MP.Eval(), p, variables(p) => x)
end
function (p::Monomial)(x::AbstractVector{<:Number})
return MP.substitute(MP.Eval(), p, variables(p) => x)
return MP.substitute(MP.Eval(), p, x)
end
(p::Monomial)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x)
function (p::_Term)(x::NTuple{N,<:Number}) where {N}
return MP.substitute(MP.Eval(), p, variables(p) => x)
end
function (p::_Term)(x::AbstractVector{<:Number})
return MP.substitute(MP.Eval(), p, variables(p) => x)
return MP.substitute(MP.Eval(), p, x)
end
(p::_Term)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x)
function (p::Polynomial)(x::NTuple{N,<:Number}) where {N}
return MP.substitute(MP.Eval(), p, variables(p) => x)
end
function (p::Polynomial)(x::AbstractVector{<:Number})
return MP.substitute(MP.Eval(), p, variables(p) => x)
return MP.substitute(MP.Eval(), p, x)
end
(p::Polynomial)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x)
16 changes: 16 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,23 @@ using MultivariatePolynomials
using Test
using LinearAlgebra

function alloc_test_lt(f, n)
f() # compile
@test n >= @allocated f()
end

# TODO move to MP
@testset "See https://github.com/jump-dev/SumOfSquares.jl/issues/388" begin
@polyvar x[1:3]
p = sum(x)
v = map(_ -> 1, x)
# I get 208 but let's give some margin
alloc_test_lt(() -> substitute(Eval(), p, x => v), 300)
alloc_test_lt(() -> p(x => v), 300)
alloc_test_lt(() -> substitute(Eval(), p, v), 0)
alloc_test_lt(() -> p(v), 0)
end

@testset "Issue #70" begin
@ncpolyvar y0 y1 x0 x1
p = x1 * x0 * x1
Expand Down

0 comments on commit 61f6766

Please sign in to comment.