Skip to content

Commit

Permalink
feat: move fast_substitute to Symbolics, implement SII.symbolic_evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 25, 2024
1 parent c6cf755 commit 2be328e
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ SciMLBase = "2"
Setfield = "1"
SpecialFunctions = "2"
StaticArrays = "1.1"
SymbolicIndexingInterface = "0.3"
SymbolicIndexingInterface = "0.3.12"
SymbolicLimits = "0.2.0"
SymbolicUtils = "1.4"
julia = "1.10"
Expand Down
92 changes: 92 additions & 0 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,98 @@ end

SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val)

function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symbolic, Equation, Inequality}, d::Dict; kwargs...)
fixpoint_sub(ex, d; kwargs...)
end

function fixpoint_sub(x, dict; operator = Nothing)
y = fast_substitute(x, dict; operator)
while !isequal(x, y)
y = x
x = fast_substitute(y, dict; operator)
end

return x
end

const Eq = Union{Equation, Inequality}
function fast_substitute(eq::Eq, subs; operator = Nothing)
if eq isa Inequality
Inequality(fast_substitute(eq.lhs, subs; operator),
fast_substitute(eq.rhs, subs; operator),
eq.relational_op)
else
Equation(fast_substitute(eq.lhs, subs; operator),
fast_substitute(eq.rhs, subs; operator))
end
end
function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq}
T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator))
end
function fast_substitute(eqs::AbstractArray, subs; operator = Nothing)
fast_substitute.(eqs, (subs,); operator)
end
function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing)
fast_substitute.(eqs, (subs,); operator)
end
for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair))
@eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing)
fast_substitute(value(expr), subs; operator)
end
end
function fast_substitute(expr, subs; operator = Nothing)
if (_val = get(subs, expr, nothing)) !== nothing
return _val
end
istree(expr) || return expr
op = fast_substitute(operation(expr), subs; operator)
args = SymbolicUtils.unsorted_arguments(expr)
if !(op isa operator)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(args) do x
x′ = fast_substitute(x, subs; operator)
canfold[] = canfold[] && !(x′ isa Symbolic)
x′
end
end
canfold[] && return op(args...)
end
similarterm(expr,
op,
args,
symtype(expr);
metadata = metadata(expr))
end
function fast_substitute(expr, pair::Pair; operator = Nothing)
a, b = pair
isequal(expr, a) && return b
if a isa AbstractArray
for (ai, bi) in zip(a, b)
expr = fast_substitute(expr, ai => bi; operator)
end
end
istree(expr) || return expr
op = fast_substitute(operation(expr), pair; operator)
args = SymbolicUtils.unsorted_arguments(expr)
if !(op isa operator)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(args) do x
x′ = fast_substitute(x, pair; operator)
canfold[] = canfold[] && !(x′ isa Symbolic)
x′
end
end
canfold[] && return op(args...)
end
similarterm(expr,
op,
args,
symtype(expr);
metadata = metadata(expr))
end

function getparent(x, val=_fail)
maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing)
if maybe_parent !== nothing
Expand Down
8 changes: 8 additions & 0 deletions test/symbolic_indexing_interface_trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,11 @@ using SymbolicIndexingInterface
@variables y[1:3]
@test symbolic_type(y) == ArraySymbolic()
@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))

@variables x y z
subs = Dict(x => 0.1, y => 2z)
subs2 = merge(subs, Dict(z => 2x+3))

@test symbolic_evaluate(x, subs) == 0.1
@test isequal(symbolic_evaluate(y, subs), 2z)
@test symbolic_evaluate(y, subs2) == 6.4

0 comments on commit 2be328e

Please sign in to comment.