Skip to content

Commit

Permalink
feat: add symbolic_evaluate to interface, implement observed for …
Browse files Browse the repository at this point in the history
…`SymbolCache`
  • Loading branch information
AayushSabharwal committed Mar 11, 2024
1 parent 405b1a5 commit 7713bc0
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 16 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
authors = ["Aayush Sabharwal <[email protected]> and contributors"]
version = "0.3.10"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"

[compat]
Aqua = "0.8"
MacroTools = "0.5.13"
RuntimeGeneratedFunctions = "0.5"
SafeTestsets = "0.0.1"
Test = "1"
julia = "1.10"
Expand Down
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ parameter_symbols
is_independent_variable
independent_variable_symbols
is_observed
default_values
is_time_dependent
constant_structure
all_variable_symbols
Expand Down Expand Up @@ -73,6 +74,7 @@ NotSymbolic
symbolic_type
hasname
getname
symbolic_evaluate
```

# Types
Expand Down
31 changes: 31 additions & 0 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,37 @@ end
[`hasname`](@ref) is not required to always be `true` for symbolic types. For example,
`Symbolics.Num` returns `false` whenever the wrapped value is a number, or an expression.

Introducing a type to represent expression trees:

```julia
struct MyExpr
op::Function
args::Vector{Union{MyExpr, MySym, MySymArr, Number, Array}}
end
```

[`symbolic_evaluate`](@ref) can be implemented as follows:

```julia
function symbolic_evaluate(expr::Union{MySym, MySymArr}, syms::Dict)
get(syms, expr, expr)
end
function symbolic_evaluate(expr::MyExpr, syms::Dict)
for i in eachindex(expr.args)
if expr.args[i] isa Union{MySym, MySymArr, MyExpr}
expr.args[i] = symbolic_evaluate(expr.args[i], syms)
end
end
if all(x -> symbolic_type(x) === NotSymbolic(), expr.args)
return expr.op(expr.args...)
end
end
```

Note the evaluation of the operation if all of the arguments are not symbolic. This is
required since `symbolic_evaluate` must return an evaluated value if all symbolic variables
are substituted.

## Parameter Timeseries

If a solution object saves modified parameter values (such as through callbacks) during the
Expand Down
6 changes: 5 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
module SymbolicIndexingInterface

import MacroTools
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname,
Timeseries, NotTimeseries, is_timeseries
include("trait.jl")
Expand All @@ -9,7 +13,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in
is_observed,
observed, is_time_dependent, constant_structure, symbolic_container,
all_variable_symbols,
all_symbols, solvedvariables, allvariables, default_values
all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate
include("interface.jl")

export SymbolCache
Expand Down
4 changes: 2 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ all_symbols(sys) = all_symbols(symbolic_container(sys))
"""
default_values(sys)
Return a dictionary mapping symbols in the system to their default value, if any. This includes
parameter symbols.
Return a dictionary mapping symbols in the system to their default value, if any. This
includes parameter symbols. The dictionary must be mutable.
"""
default_values(sys) = default_values(symbolic_container(sys))

Check warning on line 143 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L143

Added line #L143 was not covered by tests

Expand Down
67 changes: 56 additions & 11 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,35 @@
struct SymbolCache{V,P,I}
function SymbolCache(vars, [params, [indepvars]])
A struct implementing the symbolic indexing interface for the trivial case
of having a vector of variables, parameters, and independent variables. This
struct does not implement `observed`, and `is_observed` returns `false` for
all input symbols. It is considered time dependent if it contains
at least one independent variable.
A struct implementing the symbolic indexing interface for the trivial case of having a
vector of variables, parameters, and independent variables. It is considered time
dependent if it contains at least one independent variable. It returns `true` for
`is_observed(::SymbolCache, sym)` if `sym isa Expr`. Functions can be generated using
`observed` for `Expr`s involving variables in the `SymbolCache` if it has at most one
independent variable.
The independent variable may be specified as a single symbolic variable instead of an
array containing a single variable if the system has only one independent variable.
"""
struct SymbolCache{
V <: Union{Nothing, AbstractVector},
P <: Union{Nothing, AbstractVector},
I
I,
D <: Dict
}
variables::V
parameters::P
independent_variables::I
defaults::D
end

function SymbolCache(vars = nothing, params = nothing, indepvars = nothing)
return SymbolCache{typeof(vars), typeof(params), typeof(indepvars)}(vars,
function SymbolCache(vars = nothing, params = nothing, indepvars = nothing;
defaults = Dict{Symbol, Union{Symbol, Expr, Number}}())
return SymbolCache{typeof(vars), typeof(params), typeof(indepvars), typeof(defaults)}(
vars,
params,
indepvars)
indepvars,
defaults)
end

function is_variable(sc::SymbolCache, sym)
Expand Down Expand Up @@ -62,6 +68,45 @@ function independent_variable_symbols(sc::SymbolCache)
end
end
is_observed(sc::SymbolCache, sym) = false
is_observed(::SymbolCache, ::Expr) = true
function observed(sc::SymbolCache, expr::Expr)
let cache = Dict{Expr, Function}()
return get!(cache, expr) do
fnbody = Expr(:block)
declared = Set{Symbol}()
MacroTools.postwalk(expr) do sym
sym isa Symbol || return
sym in declared && return
if sc.variables !== nothing &&

Check warning on line 80 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L71-L80

Added lines #L71 - L80 were not covered by tests
(idx = findfirst(isequal(sym), sc.variables)) !== nothing
push!(fnbody.args, :($sym = u[$idx]))
push!(declared, sym)
elseif sc.parameters !== nothing &&

Check warning on line 84 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L82-L84

Added lines #L82 - L84 were not covered by tests
(idx = findfirst(isequal(sym), sc.parameters)) !== nothing
push!(fnbody.args, :($sym = p[$idx]))
push!(declared, sym)
elseif sym === sc.independent_variables ||

Check warning on line 88 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L86-L88

Added lines #L86 - L88 were not covered by tests
sc.independent_variables isa Vector &&
sym == only(sc.independent_variables)
push!(fnbody.args, :($sym = t))
push!(declared, sym)

Check warning on line 92 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end
end
fnexpr = if is_time_dependent(sc)
:(function (u, p, t)
$fnbody
return $expr

Check warning on line 98 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L95-L98

Added lines #L95 - L98 were not covered by tests
end)
else
:(function (u, p)
$fnbody
return $expr

Check warning on line 103 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L101-L103

Added lines #L101 - L103 were not covered by tests
end)
end
return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr)

Check warning on line 106 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L106

Added line #L106 was not covered by tests
end
end
end
function is_time_dependent(sc::SymbolCache)
sc.independent_variables === nothing && return false
if symbolic_type(sc.independent_variables) == NotSymbolic()
Expand All @@ -75,11 +120,11 @@ all_variable_symbols(sc::SymbolCache) = variable_symbols(sc)
function all_symbols(sc::SymbolCache)
vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))
end
default_values(::SymbolCache) = Dict()
default_values(sc::SymbolCache) = sc.defaults

Check warning on line 123 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L123

Added line #L123 was not covered by tests

function Base.copy(sc::SymbolCache)
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),
sc.parameters === nothing ? nothing : copy(sc.parameters),
sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) :
sc.independent_variables)
sc.independent_variables, copy(sc.defaults))
end
27 changes: 27 additions & 0 deletions src/trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,33 @@ Get the name of a symbolic variable as a `Symbol`
"""
function getname end

"""
symbolic_evaluate(expr, syms::Dict)
Return the value of symbolic expression `expr` where the values of variables involved are
obtained from the dictionary `syms`. The keys of `syms` are symbolic variables (not
expressions of variables). The values of `syms` can be values or symbolic
expressions.
The returned value should either be a value or an expression involving symbolic variables
not present as keys in `syms`.
This is already implemented for
`symbolic_evaluate(expr::Union{Symbol, Expr}, syms::Dict{Symbol})`.
"""
function symbolic_evaluate(expr::Union{Symbol, Expr}, syms::Dict{Symbol})
while (new_expr = MacroTools.postwalk(expr) do sym
return get(syms, sym, sym)

Check warning on line 80 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L78-L80

Added lines #L78 - L80 were not covered by tests
end) != expr
expr = new_expr
end
return try
eval(expr)

Check warning on line 85 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L82-L85

Added lines #L82 - L85 were not covered by tests
catch
expr

Check warning on line 87 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L87

Added line #L87 was not covered by tests
end
end

############ IsTimeseriesTrait

abstract type IsTimeseriesTrait end
Expand Down
20 changes: 18 additions & 2 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using SymbolicIndexingInterface
using Test

sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
sc = SymbolCache(
[:x, :y, :z], [:a, :b], [:t]; defaults = Dict(:x => 1, :y => :(2b), :b => :(2a + x)))

@test all(is_variable.((sc,), [:x, :y, :z]))
@test all(.!is_variable.((sc,), [:a, :b, :t, :q]))
Expand All @@ -19,11 +20,26 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
@test independent_variable_symbols(sc) == [:t]
@test all_variable_symbols(sc) == [:x, :y, :z]
@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z]
@test isempty(default_values(sc))
@test default_values(sc)[:x] == 1
@test default_values(sc)[:y] == :(2b)
@test default_values(sc)[:b] == :(2a + x)

@test symbolic_evaluate(:x, default_values(sc)) == 1
@test symbolic_evaluate(:y, default_values(sc)) == :(2 * (2a + 1))
@test symbolic_evaluate(:(x + y), merge(default_values(sc), Dict(:a => 2))) == 11

@test is_observed(sc, :(x + a + t))
obsfn = observed(sc, :(x + a + t))
@test obsfn(ones(3), 2ones(2), 3.0) == 6.0
obsfn2 = observed(sc, :(x + a + t))
@test obsfn === obsfn2

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
@test is_observed(sc, :(x + b))
obsfn = observed(sc, :(x + b))
@test obsfn(ones(2), 2ones(2)) == 3.0
# make sure the constructor works
@test_nowarn SymbolCache([:x, :y])

Expand Down

0 comments on commit 7713bc0

Please sign in to comment.