From 7713bc0bc29b8e6abfe7b3eb8b7ef60d4832d054 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Mar 2024 13:00:35 +0530 Subject: [PATCH] feat: add `symbolic_evaluate` to interface, implement `observed` for `SymbolCache` --- Project.toml | 6 +++ docs/src/api.md | 2 + docs/src/complete_sii.md | 31 +++++++++++++++ src/SymbolicIndexingInterface.jl | 6 ++- src/interface.jl | 4 +- src/symbol_cache.jl | 67 ++++++++++++++++++++++++++------ src/trait.jl | 27 +++++++++++++ test/symbol_cache_test.jl | 20 +++++++++- 8 files changed, 147 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index 8d17bf91..8d03a60d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,8 +3,14 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" authors = ["Aayush Sabharwal and contributors"] version = "0.3.10" +[deps] +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" + [compat] Aqua = "0.8" +MacroTools = "0.5.13" +RuntimeGeneratedFunctions = "0.5" SafeTestsets = "0.0.1" Test = "1" julia = "1.10" diff --git a/docs/src/api.md b/docs/src/api.md index 47ec0750..b297681f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -13,6 +13,7 @@ parameter_symbols is_independent_variable independent_variable_symbols is_observed +default_values is_time_dependent constant_structure all_variable_symbols @@ -73,6 +74,7 @@ NotSymbolic symbolic_type hasname getname +symbolic_evaluate ``` # Types diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index 65d6e723..a796a9c8 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -344,6 +344,37 @@ end [`hasname`](@ref) is not required to always be `true` for symbolic types. For example, `Symbolics.Num` returns `false` whenever the wrapped value is a number, or an expression. +Introducing a type to represent expression trees: + +```julia +struct MyExpr + op::Function + args::Vector{Union{MyExpr, MySym, MySymArr, Number, Array}} +end +``` + +[`symbolic_evaluate`](@ref) can be implemented as follows: + +```julia +function symbolic_evaluate(expr::Union{MySym, MySymArr}, syms::Dict) + get(syms, expr, expr) +end +function symbolic_evaluate(expr::MyExpr, syms::Dict) + for i in eachindex(expr.args) + if expr.args[i] isa Union{MySym, MySymArr, MyExpr} + expr.args[i] = symbolic_evaluate(expr.args[i], syms) + end + end + if all(x -> symbolic_type(x) === NotSymbolic(), expr.args) + return expr.op(expr.args...) + end +end +``` + +Note the evaluation of the operation if all of the arguments are not symbolic. This is +required since `symbolic_evaluate` must return an evaluated value if all symbolic variables +are substituted. + ## Parameter Timeseries If a solution object saves modified parameter values (such as through callbacks) during the diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index cc52da04..0d543edf 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -1,5 +1,9 @@ module SymbolicIndexingInterface +import MacroTools +using RuntimeGeneratedFunctions +RuntimeGeneratedFunctions.init(@__MODULE__) + export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname, Timeseries, NotTimeseries, is_timeseries include("trait.jl") @@ -9,7 +13,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in is_observed, observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols, - all_symbols, solvedvariables, allvariables, default_values + all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate include("interface.jl") export SymbolCache diff --git a/src/interface.jl b/src/interface.jl index e7605cdf..76ea1d7b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -137,8 +137,8 @@ all_symbols(sys) = all_symbols(symbolic_container(sys)) """ default_values(sys) -Return a dictionary mapping symbols in the system to their default value, if any. This includes -parameter symbols. +Return a dictionary mapping symbols in the system to their default value, if any. This +includes parameter symbols. The dictionary must be mutable. """ default_values(sys) = default_values(symbolic_container(sys)) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index b923e1ac..0db5e927 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -2,11 +2,12 @@ struct SymbolCache{V,P,I} function SymbolCache(vars, [params, [indepvars]]) -A struct implementing the symbolic indexing interface for the trivial case -of having a vector of variables, parameters, and independent variables. This -struct does not implement `observed`, and `is_observed` returns `false` for -all input symbols. It is considered time dependent if it contains -at least one independent variable. +A struct implementing the symbolic indexing interface for the trivial case of having a +vector of variables, parameters, and independent variables. It is considered time +dependent if it contains at least one independent variable. It returns `true` for +`is_observed(::SymbolCache, sym)` if `sym isa Expr`. Functions can be generated using +`observed` for `Expr`s involving variables in the `SymbolCache` if it has at most one +independent variable. The independent variable may be specified as a single symbolic variable instead of an array containing a single variable if the system has only one independent variable. @@ -14,17 +15,22 @@ array containing a single variable if the system has only one independent variab struct SymbolCache{ V <: Union{Nothing, AbstractVector}, P <: Union{Nothing, AbstractVector}, - I + I, + D <: Dict } variables::V parameters::P independent_variables::I + defaults::D end -function SymbolCache(vars = nothing, params = nothing, indepvars = nothing) - return SymbolCache{typeof(vars), typeof(params), typeof(indepvars)}(vars, +function SymbolCache(vars = nothing, params = nothing, indepvars = nothing; + defaults = Dict{Symbol, Union{Symbol, Expr, Number}}()) + return SymbolCache{typeof(vars), typeof(params), typeof(indepvars), typeof(defaults)}( + vars, params, - indepvars) + indepvars, + defaults) end function is_variable(sc::SymbolCache, sym) @@ -62,6 +68,45 @@ function independent_variable_symbols(sc::SymbolCache) end end is_observed(sc::SymbolCache, sym) = false +is_observed(::SymbolCache, ::Expr) = true +function observed(sc::SymbolCache, expr::Expr) + let cache = Dict{Expr, Function}() + return get!(cache, expr) do + fnbody = Expr(:block) + declared = Set{Symbol}() + MacroTools.postwalk(expr) do sym + sym isa Symbol || return + sym in declared && return + if sc.variables !== nothing && + (idx = findfirst(isequal(sym), sc.variables)) !== nothing + push!(fnbody.args, :($sym = u[$idx])) + push!(declared, sym) + elseif sc.parameters !== nothing && + (idx = findfirst(isequal(sym), sc.parameters)) !== nothing + push!(fnbody.args, :($sym = p[$idx])) + push!(declared, sym) + elseif sym === sc.independent_variables || + sc.independent_variables isa Vector && + sym == only(sc.independent_variables) + push!(fnbody.args, :($sym = t)) + push!(declared, sym) + end + end + fnexpr = if is_time_dependent(sc) + :(function (u, p, t) + $fnbody + return $expr + end) + else + :(function (u, p) + $fnbody + return $expr + end) + end + return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr) + end + end +end function is_time_dependent(sc::SymbolCache) sc.independent_variables === nothing && return false if symbolic_type(sc.independent_variables) == NotSymbolic() @@ -75,11 +120,11 @@ all_variable_symbols(sc::SymbolCache) = variable_symbols(sc) function all_symbols(sc::SymbolCache) vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc)) end -default_values(::SymbolCache) = Dict() +default_values(sc::SymbolCache) = sc.defaults function Base.copy(sc::SymbolCache) return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), sc.parameters === nothing ? nothing : copy(sc.parameters), sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) : - sc.independent_variables) + sc.independent_variables, copy(sc.defaults)) end diff --git a/src/trait.jl b/src/trait.jl index 29c6673f..f2e3a297 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -61,6 +61,33 @@ Get the name of a symbolic variable as a `Symbol` """ function getname end +""" + symbolic_evaluate(expr, syms::Dict) + +Return the value of symbolic expression `expr` where the values of variables involved are +obtained from the dictionary `syms`. The keys of `syms` are symbolic variables (not +expressions of variables). The values of `syms` can be values or symbolic +expressions. + +The returned value should either be a value or an expression involving symbolic variables +not present as keys in `syms`. + +This is already implemented for +`symbolic_evaluate(expr::Union{Symbol, Expr}, syms::Dict{Symbol})`. +""" +function symbolic_evaluate(expr::Union{Symbol, Expr}, syms::Dict{Symbol}) + while (new_expr = MacroTools.postwalk(expr) do sym + return get(syms, sym, sym) + end) != expr + expr = new_expr + end + return try + eval(expr) + catch + expr + end +end + ############ IsTimeseriesTrait abstract type IsTimeseriesTrait end diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index 5f57b120..6a1ffa90 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -1,7 +1,8 @@ using SymbolicIndexingInterface using Test -sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +sc = SymbolCache( + [:x, :y, :z], [:a, :b], [:t]; defaults = Dict(:x => 1, :y => :(2b), :b => :(2a + x))) @test all(is_variable.((sc,), [:x, :y, :z])) @test all(.!is_variable.((sc,), [:a, :b, :t, :q])) @@ -19,11 +20,26 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) @test independent_variable_symbols(sc) == [:t] @test all_variable_symbols(sc) == [:x, :y, :z] @test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z] -@test isempty(default_values(sc)) +@test default_values(sc)[:x] == 1 +@test default_values(sc)[:y] == :(2b) +@test default_values(sc)[:b] == :(2a + x) + +@test symbolic_evaluate(:x, default_values(sc)) == 1 +@test symbolic_evaluate(:y, default_values(sc)) == :(2 * (2a + 1)) +@test symbolic_evaluate(:(x + y), merge(default_values(sc), Dict(:a => 2))) == 11 + +@test is_observed(sc, :(x + a + t)) +obsfn = observed(sc, :(x + a + t)) +@test obsfn(ones(3), 2ones(2), 3.0) == 6.0 +obsfn2 = observed(sc, :(x + a + t)) +@test obsfn === obsfn2 sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) @test sort(all_symbols(sc)) == [:a, :b, :x, :y] +@test is_observed(sc, :(x + b)) +obsfn = observed(sc, :(x + b)) +@test obsfn(ones(2), 2ones(2)) == 3.0 # make sure the constructor works @test_nowarn SymbolCache([:x, :y])