From 2be328e5188f880ca788285c58a1afdd8884180d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 12 Mar 2024 11:15:59 +0530 Subject: [PATCH 1/3] feat: move fast_substitute to Symbolics, implement SII.symbolic_evaluate --- Project.toml | 2 +- src/variable.jl | 92 +++++++++++++++++++++++ test/symbolic_indexing_interface_trait.jl | 8 ++ 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5f4228de6..95cf62dc0 100644 --- a/Project.toml +++ b/Project.toml @@ -81,7 +81,7 @@ SciMLBase = "2" Setfield = "1" SpecialFunctions = "2" StaticArrays = "1.1" -SymbolicIndexingInterface = "0.3" +SymbolicIndexingInterface = "0.3.12" SymbolicLimits = "0.2.0" SymbolicUtils = "1.4" julia = "1.10" diff --git a/src/variable.jl b/src/variable.jl index aba9ba502..e4b0eaf42 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -410,6 +410,98 @@ end SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val) +function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symbolic, Equation, Inequality}, d::Dict; kwargs...) + fixpoint_sub(ex, d; kwargs...) +end + +function fixpoint_sub(x, dict; operator = Nothing) + y = fast_substitute(x, dict; operator) + while !isequal(x, y) + y = x + x = fast_substitute(y, dict; operator) + end + + return x +end + +const Eq = Union{Equation, Inequality} +function fast_substitute(eq::Eq, subs; operator = Nothing) + if eq isa Inequality + Inequality(fast_substitute(eq.lhs, subs; operator), + fast_substitute(eq.rhs, subs; operator), + eq.relational_op) + else + Equation(fast_substitute(eq.lhs, subs; operator), + fast_substitute(eq.rhs, subs; operator)) + end +end +function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq} + T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator)) +end +function fast_substitute(eqs::AbstractArray, subs; operator = Nothing) + fast_substitute.(eqs, (subs,); operator) +end +function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing) + fast_substitute.(eqs, (subs,); operator) +end +for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair)) + @eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing) + fast_substitute(value(expr), subs; operator) + end +end +function fast_substitute(expr, subs; operator = Nothing) + if (_val = get(subs, expr, nothing)) !== nothing + return _val + end + istree(expr) || return expr + op = fast_substitute(operation(expr), subs; operator) + args = SymbolicUtils.unsorted_arguments(expr) + if !(op isa operator) + canfold = Ref(!(op isa Symbolic)) + args = let canfold = canfold + map(args) do x + x′ = fast_substitute(x, subs; operator) + canfold[] = canfold[] && !(x′ isa Symbolic) + x′ + end + end + canfold[] && return op(args...) + end + similarterm(expr, + op, + args, + symtype(expr); + metadata = metadata(expr)) +end +function fast_substitute(expr, pair::Pair; operator = Nothing) + a, b = pair + isequal(expr, a) && return b + if a isa AbstractArray + for (ai, bi) in zip(a, b) + expr = fast_substitute(expr, ai => bi; operator) + end + end + istree(expr) || return expr + op = fast_substitute(operation(expr), pair; operator) + args = SymbolicUtils.unsorted_arguments(expr) + if !(op isa operator) + canfold = Ref(!(op isa Symbolic)) + args = let canfold = canfold + map(args) do x + x′ = fast_substitute(x, pair; operator) + canfold[] = canfold[] && !(x′ isa Symbolic) + x′ + end + end + canfold[] && return op(args...) + end + similarterm(expr, + op, + args, + symtype(expr); + metadata = metadata(expr)) +end + function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) if maybe_parent !== nothing diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl index 52d1579ae..6c8af1457 100644 --- a/test/symbolic_indexing_interface_trait.jl +++ b/test/symbolic_indexing_interface_trait.jl @@ -10,3 +10,11 @@ using SymbolicIndexingInterface @variables y[1:3] @test symbolic_type(y) == ArraySymbolic() @test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),)) + +@variables x y z +subs = Dict(x => 0.1, y => 2z) +subs2 = merge(subs, Dict(z => 2x+3)) + +@test symbolic_evaluate(x, subs) == 0.1 +@test isequal(symbolic_evaluate(y, subs), 2z) +@test symbolic_evaluate(y, subs2) == 6.4 From f0268e13689f2d61946d8cdb3eb00c3bed2495d5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 25 Mar 2024 13:57:09 +0530 Subject: [PATCH 2/3] test: add tests for symbolic_evaluate to SII testset --- test/runtests.jl | 3 ++ ...ic_indexing_interface_symbolic_evaluate.jl | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 test/symbolic_indexing_interface_symbolic_evaluate.jl diff --git a/test/runtests.jl b/test/runtests.jl index fe58b97fd..cda10e6bf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,9 @@ if GROUP == "All" || GROUP == "Core" || GROUP == "SymbolicIndexingInterface" @safetestset "SymbolicIndexingInterface Parameter Indexing Test" begin include("symbolic_indexing_interface_parameter_indexing.jl") end + @safetestset "SymbolicIndexingInterface Symbolic Evaluate Test" begin + include("symbolic_indexing_interface_symbolic_evaluate.jl") + end end if GROUP == "Downstream" diff --git a/test/symbolic_indexing_interface_symbolic_evaluate.jl b/test/symbolic_indexing_interface_symbolic_evaluate.jl new file mode 100644 index 000000000..251bef619 --- /dev/null +++ b/test/symbolic_indexing_interface_symbolic_evaluate.jl @@ -0,0 +1,36 @@ +using Symbolics +using SymbolicIndexingInterface +using Symbolics: Differential, Operator + +@variables t x(t) y(t) +@variables p[1:3, 1:3] q[1:3] + +bar(x, p) = p * x +@register_array_symbolic bar(x::AbstractVector, p::AbstractMatrix) begin + size = size(x) + eltype = promote_type(eltype(x), eltype(p)) +end + +D = Differential(t) + +expr1 = x + y + D(x) +@test isequal(symbolic_evaluate(expr1, Dict(x => 3)), 3 + y + D(3)) +@test isequal(symbolic_evaluate(expr1, Dict(x => 3); operator = Operator), 3 + y + D(x)) +@test isequal(symbolic_evaluate(expr1, Dict(x => 1, D(x) => 2)), y + 3) +@test symbolic_evaluate(expr1, Dict(x => 1, D(x) => 2, y => 3)) == 6 +@test isequal(symbolic_evaluate(expr1, Dict(x => 3, y => 3x), operator = Operator), 12 + D(x)) +@test symbolic_evaluate(expr1, Dict(x => 3, y => 3x, D(x) => 2)) == 14 + +expr2 = bar(q, p) +@test isequal(symbolic_evaluate(expr2, Dict(p => ones(3, 3))), bar(q, ones(3, 3))) +@test symbolic_evaluate(expr2, Dict(p => ones(3, 3), q => ones(3))) == 3ones(3) + +expr3 = bar(3q, 3p) +@test isequal(symbolic_evaluate(expr3, Dict(p => ones(3, 3))), bar(3q, 3ones(3, 3))) +@test symbolic_evaluate(expr3, Dict(p => ones(3, 3), q => ones(3))) == 27ones(3) + +expr4 = D(x) ~ 3x + y +@test isequal(symbolic_evaluate(expr4, Dict(x => 3)), D(3) ~ 9 + y) +@test isequal(symbolic_evaluate(expr4, Dict(x => 3); operator = Operator), D(x) ~ y + 9) +@test isequal(symbolic_evaluate(expr4, Dict(x => 1, D(x) => 2)), 2 ~ 3 + y) +@test isequal(symbolic_evaluate(expr4, Dict(x => 1, D(x) => 2, y => 3)), 2 ~ 6) From 33ce32b6304e22a59f0df54484bb75be93e94dec Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 25 Mar 2024 15:03:46 +0530 Subject: [PATCH 3/3] docs: add docs for fixpoint_sub and fast_substitute --- docs/src/manual/expression_manipulation.md | 2 ++ src/variable.jl | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/docs/src/manual/expression_manipulation.md b/docs/src/manual/expression_manipulation.md index 3dfbbd44b..102aa985c 100644 --- a/docs/src/manual/expression_manipulation.md +++ b/docs/src/manual/expression_manipulation.md @@ -31,4 +31,6 @@ Symbolics.coeff Symbolics.replace Symbolics.occursin Symbolics.filterchildren +Symbolics.fixpoint_sub +Symbolics.fast_substitute ``` diff --git a/src/variable.jl b/src/variable.jl index e4b0eaf42..d2b461d80 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -414,6 +414,17 @@ function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symboli fixpoint_sub(ex, d; kwargs...) end +""" + fixpoint_sub(expr, dict; operator = Nothing) + +Given a symbolic expression, equation or inequality `expr` perform the substitutions in +`dict` recursively until the expression does not change. Substitutions that depend on one +another will thus be recursively expanded. For example, +`fixpoint_sub(x, Dict(x => y, y => 3))` will return `3`. The `operator` keyword can be +specified to prevent substitution of expressions inside operators of the given type. + +See also: [`fast_substitute`](@ref). +""" function fixpoint_sub(x, dict; operator = Nothing) y = fast_substitute(x, dict; operator) while !isequal(x, y) @@ -425,6 +436,16 @@ function fixpoint_sub(x, dict; operator = Nothing) end const Eq = Union{Equation, Inequality} +""" + fast_substitute(expr, dict; operator = Nothing) + +Given a symbolic expression, equation or inequality `expr` perform the substitutions in +`dict`. This only performs the substitutions once. For example, +`fast_substitute(x, Dict(x => y, y => 3))` will return `y`. The `operator` keyword can be +specified to prevent substitution of expressions inside operators of the given type. + +See also: [`fixpoint_sub`](@ref). +""" function fast_substitute(eq::Eq, subs; operator = Nothing) if eq isa Inequality Inequality(fast_substitute(eq.lhs, subs; operator),