Skip to content

Commit

Permalink
feat: move fast_substitute to Symbolics, implement SII.symbolic_evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 12, 2024
1 parent f39e633 commit 078928d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ SciMLBase = "2"
Setfield = "1"
SpecialFunctions = "2"
StaticArrays = "1.1"
SymbolicIndexingInterface = "0.3"
SymbolicIndexingInterface = "0.3.11"
SymbolicUtils = "1.4"
julia = "1.10"

Expand Down
52 changes: 52 additions & 0 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,58 @@ end

SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val)

function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symbolic}, d::Dict)
fixpoint_sub(ex, d)
end

function fixpoint_sub(x, dict)
y = fast_substitute(x, dict)
while !isequal(x, y)
y = x
x = fast_substitute(y, dict)
end

return x
end

const Eq = Union{Equation, Inequality}
# substitute without unwrapping
function fast_substitute(eq::Eq, subs)
if eq isa Inequality
Inequality(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs),
eq.relational_op)
else
Equation(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
end
end
function fast_substitute(eq::T, subs::Pair) where {T <: Eq}
T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
end
fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(a, b) = substitute(a, b)
function fast_substitute(expr, pair::Pair)
a, b = pair
isequal(expr, a) && return b

istree(expr) || return expr
op = fast_substitute(operation(expr), pair)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(SymbolicUtils.unsorted_arguments(expr)) do x
x′ = fast_substitute(x, pair)
canfold[] = canfold[] && !(x′ isa Symbolic)
x′
end
end
canfold[] && return op(args...)

similarterm(expr,
op,
args,
symtype(expr);
metadata = metadata(expr))
end

function getparent(x, val=_fail)
maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing)
if maybe_parent !== nothing
Expand Down
8 changes: 8 additions & 0 deletions test/symbolic_indexing_interface_trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,11 @@ using SymbolicIndexingInterface
@variables y[1:3]
@test symbolic_type(y) == ArraySymbolic()
@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))

@variables x y z
subs = Dict(x => 0.1, y => 2z)
subs2 = merge(subs, Dict(z => 2x+3))

@test symbolic_evaluate(x, subs) == 0.1
@test isequal(symbolic_evaluate(y, subs), 2z)
@test symbolic_evaluate(y, subs2) == 6.4

0 comments on commit 078928d

Please sign in to comment.