Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add default_values to interface #47

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
authors = ["Aayush Sabharwal <[email protected]> and contributors"]
version = "0.3.10"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"

[compat]
Aqua = "0.8"
MacroTools = "0.5.13"
RuntimeGeneratedFunctions = "0.5"
SafeTestsets = "0.0.1"
Test = "1"
julia = "1.10"
Expand Down
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ parameter_symbols
is_independent_variable
independent_variable_symbols
is_observed
default_values
is_time_dependent
constant_structure
all_variable_symbols
Expand Down Expand Up @@ -73,6 +74,7 @@ NotSymbolic
symbolic_type
hasname
getname
symbolic_evaluate
```

# Types
Expand Down
36 changes: 36 additions & 0 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ struct ExampleSystem
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
independent_variable::Union{Symbol,Nothing}
defaults::Dict{Symbol, Float64}
# mapping from observed variable to Expr to calculate its value
observed::Dict{Symbol,Expr}
end
Expand Down Expand Up @@ -77,6 +78,10 @@ function SymbolicIndexingInterface.all_symbols(sys::ExampleSystem)
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
)
end

function SymbolicIndexingInterface.default_values(sys::ExampleSystem)
return sys.defaults
end
```

### Observed Equation Handling
Expand Down Expand Up @@ -339,6 +344,37 @@ end
[`hasname`](@ref) is not required to always be `true` for symbolic types. For example,
`Symbolics.Num` returns `false` whenever the wrapped value is a number, or an expression.

Introducing a type to represent expression trees:

```julia
struct MyExpr
op::Function
args::Vector{Union{MyExpr, MySym, MySymArr, Number, Array}}
end
```

[`symbolic_evaluate`](@ref) can be implemented as follows:

```julia
function symbolic_evaluate(expr::Union{MySym, MySymArr}, syms::Dict)
get(syms, expr, expr)
end
function symbolic_evaluate(expr::MyExpr, syms::Dict)
for i in eachindex(expr.args)
if expr.args[i] isa Union{MySym, MySymArr, MyExpr}
expr.args[i] = symbolic_evaluate(expr.args[i], syms)
end
end
if all(x -> symbolic_type(x) === NotSymbolic(), expr.args)
return expr.op(expr.args...)
end
end
```

Note the evaluation of the operation if all of the arguments are not symbolic. This is
required since `symbolic_evaluate` must return an evaluated value if all symbolic variables
are substituted.

## Parameter Timeseries

If a solution object saves modified parameter values (such as through callbacks) during the
Expand Down
6 changes: 5 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
module SymbolicIndexingInterface

import MacroTools
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname,
Timeseries, NotTimeseries, is_timeseries
include("trait.jl")
Expand All @@ -9,7 +13,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in
is_observed,
observed, is_time_dependent, constant_structure, symbolic_container,
all_variable_symbols,
all_symbols, solvedvariables, allvariables
all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate
include("interface.jl")

export SymbolCache
Expand Down
8 changes: 8 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@
"""
all_symbols(sys) = all_symbols(symbolic_container(sys))

"""
default_values(sys)

Return a dictionary mapping symbols in the system to their default value, if any. This
includes parameter symbols. The dictionary must be mutable.
"""
default_values(sys) = default_values(symbolic_container(sys))

Check warning on line 143 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L143

Added line #L143 was not covered by tests

struct SolvedVariables end

"""
Expand Down
66 changes: 56 additions & 10 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,35 @@
struct SymbolCache{V,P,I}
function SymbolCache(vars, [params, [indepvars]])

A struct implementing the symbolic indexing interface for the trivial case
of having a vector of variables, parameters, and independent variables. This
struct does not implement `observed`, and `is_observed` returns `false` for
all input symbols. It is considered time dependent if it contains
at least one independent variable.
A struct implementing the symbolic indexing interface for the trivial case of having a
vector of variables, parameters, and independent variables. It is considered time
dependent if it contains at least one independent variable. It returns `true` for
`is_observed(::SymbolCache, sym)` if `sym isa Expr`. Functions can be generated using
`observed` for `Expr`s involving variables in the `SymbolCache` if it has at most one
independent variable.

The independent variable may be specified as a single symbolic variable instead of an
array containing a single variable if the system has only one independent variable.
"""
struct SymbolCache{
V <: Union{Nothing, AbstractVector},
P <: Union{Nothing, AbstractVector},
I
I,
D <: Dict
}
variables::V
parameters::P
independent_variables::I
defaults::D
end

function SymbolCache(vars = nothing, params = nothing, indepvars = nothing)
return SymbolCache{typeof(vars), typeof(params), typeof(indepvars)}(vars,
function SymbolCache(vars = nothing, params = nothing, indepvars = nothing;
defaults = Dict{Symbol, Union{Symbol, Expr, Number}}())
return SymbolCache{typeof(vars), typeof(params), typeof(indepvars), typeof(defaults)}(
vars,
params,
indepvars)
indepvars,
defaults)
end

function is_variable(sc::SymbolCache, sym)
Expand Down Expand Up @@ -62,6 +68,45 @@
end
end
is_observed(sc::SymbolCache, sym) = false
is_observed(::SymbolCache, ::Expr) = true
function observed(sc::SymbolCache, expr::Expr)
let cache = Dict{Expr, Function}()
return get!(cache, expr) do
fnbody = Expr(:block)
declared = Set{Symbol}()
MacroTools.postwalk(expr) do sym
sym isa Symbol || return
sym in declared && return
if sc.variables !== nothing &&

Check warning on line 80 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L71-L80

Added lines #L71 - L80 were not covered by tests
(idx = findfirst(isequal(sym), sc.variables)) !== nothing
push!(fnbody.args, :($sym = u[$idx]))
push!(declared, sym)
elseif sc.parameters !== nothing &&

Check warning on line 84 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L82-L84

Added lines #L82 - L84 were not covered by tests
(idx = findfirst(isequal(sym), sc.parameters)) !== nothing
push!(fnbody.args, :($sym = p[$idx]))
push!(declared, sym)
elseif sym === sc.independent_variables ||

Check warning on line 88 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L86-L88

Added lines #L86 - L88 were not covered by tests
sc.independent_variables isa Vector &&
sym == only(sc.independent_variables)
push!(fnbody.args, :($sym = t))
push!(declared, sym)

Check warning on line 92 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end
end
fnexpr = if is_time_dependent(sc)
:(function (u, p, t)
$fnbody
return $expr

Check warning on line 98 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L95-L98

Added lines #L95 - L98 were not covered by tests
end)
else
:(function (u, p)
$fnbody
return $expr

Check warning on line 103 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L101-L103

Added lines #L101 - L103 were not covered by tests
end)
end
return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr)

Check warning on line 106 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L106

Added line #L106 was not covered by tests
end
end
end
function is_time_dependent(sc::SymbolCache)
sc.independent_variables === nothing && return false
if symbolic_type(sc.independent_variables) == NotSymbolic()
Expand All @@ -75,10 +120,11 @@
function all_symbols(sc::SymbolCache)
vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))
end
default_values(sc::SymbolCache) = sc.defaults

Check warning on line 123 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L123

Added line #L123 was not covered by tests

function Base.copy(sc::SymbolCache)
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),
sc.parameters === nothing ? nothing : copy(sc.parameters),
sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) :
sc.independent_variables)
sc.independent_variables, copy(sc.defaults))
end
27 changes: 27 additions & 0 deletions src/trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,33 @@
"""
function getname end

"""
symbolic_evaluate(expr, syms::Dict)

Return the value of symbolic expression `expr` where the values of variables involved are
obtained from the dictionary `syms`. The keys of `syms` are symbolic variables (not
expressions of variables). The values of `syms` can be values or symbolic
expressions.

The returned value should either be a value or an expression involving symbolic variables
not present as keys in `syms`.

This is already implemented for
`symbolic_evaluate(expr::Union{Symbol, Expr}, syms::Dict{Symbol})`.
"""
function symbolic_evaluate(expr::Union{Symbol, Expr}, syms::Dict{Symbol})
while (new_expr = MacroTools.postwalk(expr) do sym
return get(syms, sym, sym)

Check warning on line 80 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L78-L80

Added lines #L78 - L80 were not covered by tests
end) != expr
expr = new_expr
end
return try
eval(expr)

Check warning on line 85 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L82-L85

Added lines #L82 - L85 were not covered by tests
catch
expr

Check warning on line 87 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L87

Added line #L87 was not covered by tests
end
end

############ IsTimeseriesTrait

abstract type IsTimeseriesTrait end
Expand Down
21 changes: 20 additions & 1 deletion test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using SymbolicIndexingInterface
using Test

sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
sc = SymbolCache(
[:x, :y, :z], [:a, :b], [:t]; defaults = Dict(:x => 1, :y => :(2b), :b => :(2a + x)))

@test all(is_variable.((sc,), [:x, :y, :z]))
@test all(.!is_variable.((sc,), [:a, :b, :t, :q]))
Expand All @@ -19,10 +20,26 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
@test independent_variable_symbols(sc) == [:t]
@test all_variable_symbols(sc) == [:x, :y, :z]
@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z]
@test default_values(sc)[:x] == 1
@test default_values(sc)[:y] == :(2b)
@test default_values(sc)[:b] == :(2a + x)

@test symbolic_evaluate(:x, default_values(sc)) == 1
@test symbolic_evaluate(:y, default_values(sc)) == :(2 * (2a + 1))
@test symbolic_evaluate(:(x + y), merge(default_values(sc), Dict(:a => 2))) == 11

@test is_observed(sc, :(x + a + t))
obsfn = observed(sc, :(x + a + t))
@test obsfn(ones(3), 2ones(2), 3.0) == 6.0
obsfn2 = observed(sc, :(x + a + t))
@test obsfn === obsfn2

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
@test is_observed(sc, :(x + b))
obsfn = observed(sc, :(x + b))
@test obsfn(ones(2), 2ones(2)) == 3.0
# make sure the constructor works
@test_nowarn SymbolCache([:x, :y])

Expand All @@ -38,6 +55,7 @@ sc = SymbolCache()
@test !is_time_dependent(sc)
@test all_variable_symbols(sc) == []
@test all_symbols(sc) == []
@test isempty(default_values(sc))

sc = SymbolCache(nothing, nothing, :t)
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b]))
Expand All @@ -46,6 +64,7 @@ sc = SymbolCache(nothing, nothing, :t)
@test is_time_dependent(sc)
@test all_variable_symbols(sc) == []
@test all_symbols(sc) == [:t]
@test isempty(default_values(sc))

sc2 = copy(sc)
@test sc.variables == sc2.variables
Expand Down
Loading