From 783ab25d98b7620d12051da5c2084370fb24e9e9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 12 Mar 2024 11:15:59 +0530 Subject: [PATCH] feat: move fast_substitute to Symbolics, implement SII.symbolic_evaluate --- Project.toml | 2 +- src/variable.jl | 92 +++++++++++++++++++++++ test/symbolic_indexing_interface_trait.jl | 8 ++ 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5f4228de6..b8e0e3a42 100644 --- a/Project.toml +++ b/Project.toml @@ -81,7 +81,7 @@ SciMLBase = "2" Setfield = "1" SpecialFunctions = "2" StaticArrays = "1.1" -SymbolicIndexingInterface = "0.3" +SymbolicIndexingInterface = "0.3.11" SymbolicLimits = "0.2.0" SymbolicUtils = "1.4" julia = "1.10" diff --git a/src/variable.jl b/src/variable.jl index aba9ba502..259d23c1e 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -410,6 +410,98 @@ end SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val) +function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symbolic}, d::Dict; kwargs...) + fixpoint_sub(ex, d; kwargs...) +end + +function fixpoint_sub(x, dict; operator = Nothing) + y = fast_substitute(x, dict; operator) + while !isequal(x, y) + y = x + x = fast_substitute(y, dict; operator) + end + + return x +end + +const Eq = Union{Equation, Inequality} +function fast_substitute(eq::Eq, subs; operator = Nothing) + if eq isa Inequality + Inequality(fast_substitute(eq.lhs, subs; operator), + fast_substitute(eq.rhs, subs; operator), + eq.relational_op) + else + Equation(fast_substitute(eq.lhs, subs; operator), + fast_substitute(eq.rhs, subs; operator)) + end +end +function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq} + T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator)) +end +function fast_substitute(eqs::AbstractArray, subs; operator = Nothing) + fast_substitute.(eqs, (subs,); operator) +end +function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing) + fast_substitute.(eqs, (subs,); operator) +end +for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair)) + @eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing) + fast_substitute(value(expr), subs; operator) + end +end +function fast_substitute(expr, subs; operator = Nothing) + if (_val = get(subs, expr, nothing)) !== nothing + return _val + end + istree(expr) || return expr + op = fast_substitute(operation(expr), subs; operator) + args = SymbolicUtils.unsorted_arguments(expr) + if !(op isa operator) + canfold = Ref(!(op isa Symbolic)) + args = let canfold = canfold + map(args) do x + x′ = fast_substitute(x, subs; operator) + canfold[] = canfold[] && !(x′ isa Symbolic) + x′ + end + end + canfold[] && return op(args...) + end + similarterm(expr, + op, + args, + symtype(expr); + metadata = metadata(expr)) +end +function fast_substitute(expr, pair::Pair; operator = Nothing) + a, b = pair + isequal(expr, a) && return b + if a isa AbstractArray + for (ai, bi) in zip(a, b) + expr = fast_substitute(expr, ai => bi; operator) + end + end + istree(expr) || return expr + op = fast_substitute(operation(expr), pair; operator) + args = SymbolicUtils.unsorted_arguments(expr) + if !(op isa operator) + canfold = Ref(!(op isa Symbolic)) + args = let canfold = canfold + map(args) do x + x′ = fast_substitute(x, pair; operator) + canfold[] = canfold[] && !(x′ isa Symbolic) + x′ + end + end + canfold[] && return op(args...) + end + similarterm(expr, + op, + args, + symtype(expr); + metadata = metadata(expr)) +end + function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) if maybe_parent !== nothing diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl index 52d1579ae..6c8af1457 100644 --- a/test/symbolic_indexing_interface_trait.jl +++ b/test/symbolic_indexing_interface_trait.jl @@ -10,3 +10,11 @@ using SymbolicIndexingInterface @variables y[1:3] @test symbolic_type(y) == ArraySymbolic() @test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),)) + +@variables x y z +subs = Dict(x => 0.1, y => 2z) +subs2 = merge(subs, Dict(z => 2x+3)) + +@test symbolic_evaluate(x, subs) == 0.1 +@test isequal(symbolic_evaluate(y, subs), 2z) +@test symbolic_evaluate(y, subs2) == 6.4