From 0e6e9e794a261101abe5d6b2eff54bbf59244e2c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 13:19:16 +0530 Subject: [PATCH 01/35] feat!: rework interface --- Project.toml | 4 - src/SymbolicIndexingInterface.jl | 73 ++++++++++++-- src/interface.jl | 162 ------------------------------- src/symbolcache.jl | 26 ----- 4 files changed, 66 insertions(+), 199 deletions(-) delete mode 100644 src/interface.jl delete mode 100644 src/symbolcache.jl diff --git a/Project.toml b/Project.toml index dc81671..676e622 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,7 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" authors = ["Aayush Sabharwal and contributors"] version = "0.2.2" -[deps] -DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" - [compat] -DocStringExtensions = "0.9" julia = "1" [extras] diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index b0e4199..8892a69 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -1,13 +1,72 @@ module SymbolicIndexingInterface -using DocStringExtensions +export isvariable, variableindex, isparameter, parameterindex, isindependent_variable, + isobserved, observed, istimedependent, constant_structure -include("interface.jl") -include("symbolcache.jl") +""" + isvariable(sys, sym) -export independent_variables, is_indep_sym, states, state_sym_to_index, is_state_sym, - parameters, param_sym_to_index, is_param_sym, observed, observed_sym_to_index, - is_observed_sym, get_state_dependencies, get_observed_dependencies, - get_deps_of_observed, SymbolCache, unknown_states +Check whether the given `sym` is a variable in `sys`. +""" +function isvariable end + +""" + variableindex(sys, sym) + +Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. +""" +function variableindex end + +""" + isparameter(sys, sym) + +Check whether the given `sym` is a parameter in `sys`. +""" +function isparameter end + +""" + parameterindex(sys, sym) + +Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise. +""" +function parameterindex end + +""" + isindependent_variable(sys, sym) + +Check whether the given `sym` is an independent variable in `sys`. +""" +function isindependent_variable end + +""" + isobserved(sys, sym) + +Check whether the given `sym` is an observed value in `sys`. +""" +function isobserved end + +""" + observed(sys, sym) + +Return the observed function of the given `sym` in `sys`. The returned function should +accept a timeseries if `sys` has an independent variable, and return the observed +values for the given `sym`. +""" +function observed end + +""" + istimedependent(sys) + +Check if `sys` has time as (one of) its independent variables. +""" +function istimedependent end + +""" + constant_structure(sys) + +Check if `sys` has a constant structure. Constant structure systems do not change the +number of variables or parameters over time. +""" +function constant_structure end end diff --git a/src/interface.jl b/src/interface.jl deleted file mode 100644 index 588e1bb..0000000 --- a/src/interface.jl +++ /dev/null @@ -1,162 +0,0 @@ -""" -$(TYPEDSIGNATURES) - -Get an iterable over the independent variables for the given system. Default to an empty -vector. -""" -function independent_variables end -independent_variables(::Any) = [] - -""" -$(TYPEDSIGNATURES) - -Check if the given sym is an independent variable in the given system. Default to checking -if the given `sym` exists in the iterable returned by `independent_variables`. -""" -function is_indep_sym end - -function is_indep_sym(store, sym) - any(isequal(Symbol(sym)), Symbol.(independent_variables(store))) -end - -""" -$(TYPEDSIGNATURES) - -Get an iterable over the states for the given system. Default to an empty vector. -""" -function states end - -states(::Any) = [] - -""" -$(TYPEDSIGNATURES) - -Get an iterable over the unknown states for the given system. Default to an empty vector. -""" -function unknown_states end - -unknown_states(::Any) = [] - -""" -$(TYPEDSIGNATURES) - -Find the index of the given sym in the given system. Default to the index of the first -symbol in the iterable returned by `states` which matches the given `sym`. Return -`nothing` if the given `sym` does not match. -""" -function state_sym_to_index end - -function state_sym_to_index(store, sym) - findfirst(isequal(Symbol(sym)), Symbol.(states(store))) -end - -""" -$(TYPEDSIGNATURES) - -Check if the given sym is a state variable in the given system. Default to checking if -the value returned by `state_sym_to_index` is not `nothing`. -""" -function is_state_sym end - -is_state_sym(store, sym) = !isnothing(state_sym_to_index(store, sym)) - -""" -$(TYPEDSIGNATURES) - -Get an iterable over the parameters variables for the given system. Default to an empty -vector. -""" -function parameters end - -parameters(::Any) = [] - -""" -$(TYPEDSIGNATURES) - -Find the index of the given sym in the given system. Default to the index of the first -symbol in the iterable retruned by `parameters` which matches the given `sym`. Return -`nothing` if the given `sym` does not match. -""" -function param_sym_to_index end - -param_sym_to_index(store, sym) = findfirst(isequal(Symbol(sym)), Symbol.(parameters(store))) - -""" -$(TYPEDSIGNATURES) - -Check if the given sym is a parameter variable in the given system. Default -to checking if the value returned by `param_sym_to_index` is not `nothing`. -""" -function is_param_sym end - -is_param_sym(store, sym) = !isnothing(param_sym_to_index(store, sym)) - -""" -$(TYPEDSIGNATURES) - -Get an iterable over the observed variable expressions for the given system. -Default to an empty vector. -""" -function observed end - -observed(::Any) = [] - -""" -$(TYPEDSIGNATURES) - -Check if the given sym is an observed variable in the given system. Default -to checking if the value returned by `observed_sym_to_index` is not `nothing`. -""" -function is_observed_sym end - -is_observed_sym(store, sym) = !isnothing(observed_sym_to_index(store, sym)) - -""" -$(TYPEDSIGNATURES) - -Find the index of the given sym in the given system. Default to the index of the first -symbol in the iterable returned by `states` which matches the given `sym`. Return -`nothing` if the given `sym` does not match. -""" -function observed_sym_to_index end - -function observed_sym_to_index(store, sym) - findfirst(o -> isequal(sym, o.lhs), observed(store)) -end - -""" -$(TYPEDSIGNATURES) - -Return a list of the dependent state variables of an observed variable. Default to returning -an empty list. -""" -function get_state_dependencies end - -get_state_dependencies(store, sym) = [] - -""" -$(TYPEDSIGNATURES) - -Return a list of the dependent observed variables of an observed variable. Default to returning -an empty list. -""" -function get_observed_dependencies end - -get_observed_dependencies(store, sym) = [] - -""" -$(TYPEDSIGNATURES) - -Return a list of the dependent state variables of all observed equations of the system. -Default to returning an empty list. -""" -function get_deps_of_observed end - -function get_deps_of_observed(store) - obs = observed(store) - deps = mapreduce(vcat, obs, init = []) do eq - get_state_dependencies(store, eq.lhs) - end |> unique - - return deps -end diff --git a/src/symbolcache.jl b/src/symbolcache.jl deleted file mode 100644 index f9583ff..0000000 --- a/src/symbolcache.jl +++ /dev/null @@ -1,26 +0,0 @@ -""" - SymbolCache(syms, indepsym, paramsyms) - -A container that simply stores a vector of all syms, indepsym and paramsyms. -""" -struct SymbolCache{S, T, U} - syms::S - indepsym::T - paramsyms::U -end - -independent_variables(sc::SymbolCache) = sc.indepsym -independent_variables(::SymbolCache{S, Nothing}) where {S} = [] -is_indep_sym(::SymbolCache{S, Nothing}, _) where {S} = false -states(sc::SymbolCache) = sc.syms -states(::SymbolCache{Nothing}) = [] -state_sym_to_index(::SymbolCache{Nothing}, _) = nothing -parameters(sc::SymbolCache) = sc.paramsyms -parameters(::SymbolCache{S, T, Nothing}) where {S, T} = [] -param_sym_to_index(::SymbolCache{S, T, Nothing}, _) where {S, T} = nothing - -function Base.copy(VA::SymbolCache) - typeof(VA)((VA.syms === nothing) ? nothing : copy(VA.syms), - (VA.indepsym === nothing) ? nothing : copy(VA.indepsym), - (VA.paramsyms === nothing) ? nothing : copy(VA.paramsyms)) -end From cf92a76dd2fbe537f744c5ef9570dffe137a946c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 13:19:26 +0530 Subject: [PATCH 02/35] test: add tests --- test/default_function_test.jl | 15 -------- test/example_test.jl | 70 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 5 ++- test/symbolcache.jl | 28 -------------- 4 files changed, 73 insertions(+), 45 deletions(-) delete mode 100644 test/default_function_test.jl create mode 100644 test/example_test.jl delete mode 100644 test/symbolcache.jl diff --git a/test/default_function_test.jl b/test/default_function_test.jl deleted file mode 100644 index be7b92e..0000000 --- a/test/default_function_test.jl +++ /dev/null @@ -1,15 +0,0 @@ -using SymbolicIndexingInterface, Test - -@test independent_variables(nothing) == [] -@test states(nothing) == [] -@test parameters(nothing) == [] -@test observed(nothing) == [] -@test !is_indep_sym(nothing, :a) -@test !is_state_sym(nothing, :a) -@test !is_param_sym(nothing, :a) -@test !is_observed_sym(nothing, :a) -@test isnothing(state_sym_to_index(nothing, :a)) -@test isnothing(param_sym_to_index(nothing, :a)) -@test get_state_dependencies(nothing, :a) == [] -@test get_observed_dependencies(nothing, :a) == [] -@test get_deps_of_observed(nothing) == [] diff --git a/test/example_test.jl b/test/example_test.jl new file mode 100644 index 0000000..2f23e5d --- /dev/null +++ b/test/example_test.jl @@ -0,0 +1,70 @@ +struct SymbolCache + vars::Vector{Symbol} + params::Vector{Symbol} + indepvar::Union{Symbol, Nothing} +end + +SymbolicIndexingInterface.isvariable(sys::SymbolCache, sym) = sym in sys.vars +function SymbolicIndexingInterface.variableindex(sys::SymbolCache, sym) + findfirst(isequal(sym), sys.vars) +end +SymbolicIndexingInterface.isparameter(sys::SymbolCache, sym) = sym in sys.params +function SymbolicIndexingInterface.parameterindex(sys::SymbolCache, sym) + findfirst(isequal(sym), sys.params) +end +function SymbolicIndexingInterface.isindependent_variable(sys::SymbolCache, sym) + sys.indepvar !== nothing && isequal(sym, sys.indepvar) +end +function SymbolicIndexingInterface.isobserved(sys::SymbolCache, sym) + isvariable(sys, sym) || isparameter(sys, sym) || isindependent_variable(sys, sym) +end +function SymbolicIndexingInterface.observed(sys::SymbolCache, sym) + idx = variableindex(sys, sym) + if idx !== nothing + return istimedependent(sys) ? (t) -> [idx * i for i in t] : + () -> [idx * i for i in 1:5] + end + idx = parameterindex(sys, sym) + if idx !== nothing + return istimedependent(sys) ? (t) -> idx : () -> idx + end + if isindependent_variable(sys, sym) + return istimedependent(sys) ? (t) -> t : () -> 1:5 + end +end +SymbolicIndexingInterface.istimedependent(sys::SymbolCache) = isequal(sys.indepvar, :t) +SymbolicIndexingInterface.constant_structure(sys::SymbolCache) = true + +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) + +@test all(isvariable.((sys,), [:x, :y, :z])) +@test all(.!isvariable.((sys,), [:a, :b, :c, :t, :p, :q, :r])) +@test all(variableindex.((sys,), [:x, :z, :y]) .== [1, 3, 2]) +@test all(variableindex.((sys,), [:a, :b, :c, :t, :p, :q, :r]) .=== nothing) +@test all(isparameter.((sys,), [:a, :b, :c])) +@test all(.!isparameter.((sys,), [:x, :y, :z, :t, :p, :q, :r])) +@test all(parameterindex.((sys,), [:c, :a, :b]) .== [3, 1, 2]) +@test all(parameterindex.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing) +@test isindependent_variable(sys, :t) +@test all(.!isindependent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r])) +@test all(isobserved.((sys,), [:x, :y, :z, :a, :b, :c, :t])) +@test all(observed(sys, :x)(1:4) .== [1, 2, 3, 4]) +@test all(observed(sys, :y)(1:4) .== [2, 4, 6, 8]) +@test all(observed(sys, :z)(1:4) .== [3, 6, 9, 12]) +@test observed(sys, :a)(1:4) == 1 +@test observed(sys, :b)(1:4) == 2 +@test observed(sys, :c)(1:4) == 3 +@test observed(sys, :t)(1:4) == 1:4 +@test istimedependent(sys) +@test constant_structure(sys) + +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], nothing) + +@test !istimedependent(sys) +@test all(observed(sys, :x)() .== [1, 2, 3, 4, 5]) +@test all(observed(sys, :y)() .== [2, 4, 6, 8, 10]) +@test all(observed(sys, :z)() .== [3, 6, 9, 12, 15]) +@test observed(sys, :a)() == 1 +@test observed(sys, :b)() == 2 +@test observed(sys, :c)() == 3 +@test constant_structure(sys) diff --git a/test/runtests.jl b/test/runtests.jl index 50a538a..eb0fc78 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using SymbolicIndexingInterface using Test -@time begin @time @testset begin include("symbolcache.jl") end end -@time begin @time @testset begin include("default_function_test.jl") end end +@testset begin + @time include("example_test.jl") +end diff --git a/test/symbolcache.jl b/test/symbolcache.jl deleted file mode 100644 index 04ca4fa..0000000 --- a/test/symbolcache.jl +++ /dev/null @@ -1,28 +0,0 @@ -using SymbolicIndexingInterface, Test - -sc = SymbolCache(nothing, nothing, nothing) -@test isempty(independent_variables(sc)) -@test !is_indep_sym(sc, :a) -@test isempty(states(sc)) -@test isnothing(state_sym_to_index(sc, :a)) -@test !is_state_sym(sc, :a) -@test isempty(parameters(sc)) -@test isnothing(param_sym_to_index(sc, :a)) -@test !is_param_sym(sc, :a) - -sc = SymbolCache([:a, :b], [:t], [:c, :d]) -@test independent_variables(sc) == [:t] -@test is_indep_sym(sc, :t) -@test !is_indep_sym(sc, :a) -@test states(sc) == [:a, :b] -@test state_sym_to_index(sc, :a) == 1 -@test state_sym_to_index(sc, :b) == 2 -@test isnothing(state_sym_to_index(sc, :t)) -@test all(is_state_sym.((sc,), [:a, :b])) -@test !is_state_sym(sc, :c) -@test parameters(sc) == [:c, :d] -@test param_sym_to_index(sc, :c) == 1 -@test param_sym_to_index(sc, :d) == 2 -@test isnothing(param_sym_to_index(sc, :a)) -@test all(is_param_sym.((sc,), [:c, :d])) -@test !is_param_sym(sc, :b) From ced1e656b97801718820324213640007a6715a99 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 13:19:39 +0530 Subject: [PATCH 03/35] refactor: formatting --- docs/make.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index afd6757..23c8508 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -6,13 +6,13 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true) include("pages.jl") makedocs(sitename = "SymbolicIndexingInterface.jl", - authors = "Chris Rackauckas", - modules = [SymbolicIndexingInterface], - clean = true, doctest = false, - format = Documenter.HTML(analytics = "UA-90474609-3", - assets = ["assets/favicon.ico"], - canonical = "https://docs.sciml.ai/SymbolicIndexingInterface/stable/"), - pages = pages) + authors = "Chris Rackauckas", + modules = [SymbolicIndexingInterface], + clean = true, doctest = false, + format = Documenter.HTML(analytics = "UA-90474609-3", + assets = ["assets/favicon.ico"], + canonical = "https://docs.sciml.ai/SymbolicIndexingInterface/stable/"), + pages = pages) deploydocs(repo = "github.com/SciML/SymbolicIndexingInterface.jl.git"; - push_preview = true) + push_preview = true) From b7f47d2aea3b99dfa276d7cf1971fd9ad61b27f3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 13:22:29 +0530 Subject: [PATCH 04/35] feat: bump major version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 676e622..82ee6fd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicIndexingInterface" uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" authors = ["Aayush Sabharwal and contributors"] -version = "0.2.2" +version = "0.3.0" [compat] julia = "1" From 08d8dbb6723cfb501acfb332627f013fcabbceea Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 13:42:35 +0530 Subject: [PATCH 05/35] docs: update docs compat and docstrings --- docs/Project.toml | 2 +- docs/src/api.md | 25 +++++++++---------------- docs/src/index.md | 2 +- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index f2af07c..c81e0b6 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,4 +4,4 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [compat] Documenter = "0.27" -SymbolicIndexingInterface = "0.2" +SymbolicIndexingInterface = "0.3" diff --git a/docs/src/api.md b/docs/src/api.md index c29ad76..79aea93 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,20 +1,13 @@ # Interface Functions -Default methods cast all symbols to `Symbol` before comparing. - ```@docs -independent_variables -is_indep_sym -states -state_sym_to_index -is_state_sym -parameters -param_sym_to_index -is_param_sym +isvariable +variableindex +isparameter +parameterindex +isindependent_variable +isobserved +observed +istimedependent +constant_structure ``` - -## Concrete Types - -```@docs -SymbolCache -``` \ No newline at end of file diff --git a/docs/src/index.md b/docs/src/index.md index 6e6f536..7d6b3e9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,7 +1,7 @@ # SymbolicIndexingInterface.jl: Arrays of Arrays and Even Deeper SymbolicIndexingInterface.jl is a set of interface functions for handling containers -of symbolic variables. It also contains one such container: `SymbolCache`. +of symbolic variables. ## Installation From 5a96b7a94759c57b481dd318b243af1900008d46 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 20:49:14 +0530 Subject: [PATCH 06/35] feat: add `IsSymbolic` trait --- src/SymbolicIndexingInterface.jl | 36 ++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 8892a69..32d4bde 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -1,7 +1,39 @@ module SymbolicIndexingInterface -export isvariable, variableindex, isparameter, parameterindex, isindependent_variable, - isobserved, observed, istimedependent, constant_structure +export Symbolic, NotSymbolic, issymbolic, isvariable, variableindex, isparameter, + parameterindex, isindependent_variable, isobserved, observed, istimedependent, + constant_structure + +abstract type IsSymbolicTrait end + +""" + struct Symbolic <: IsSymbolicTrait end + +Trait indicating a type is symbolic. + +See also: [`NotSymbolic`](@ref), [`issymbolic`](@ref) +""" +struct Symbolic <: IsSymbolicTrait end + +""" + struct NotSymbolic <: IsSymbolicTrait end + +Trait indicating a type is not symbolic. + +See also: [`Symbolic`](@ref), [`issymbolic`](@ref) +""" +struct NotSymbolic <: IsSymbolicTrait end + +""" + issymbolic(x) = issymbolic(typeof(x)) + issymbolic(::Type) + +Check whether a type implements the [`Symbolic`](@ref) trait or not. Default to +[`NotSymbolic`](@ref) for all types except `Symbol`. +""" +issymbolic(x) = issymbolic(typeof(x)) +issymbolic(::Type) = NotSymbolic() +issymbolic(::Type{Symbol}) = Symbolic() """ isvariable(sys, sym) From 37eda772eac8aa8330a71e180f549c9e22572eba Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 20:49:30 +0530 Subject: [PATCH 07/35] feat: add extensions for Symbolics and SymbolicUtils --- Project.toml | 10 ++++++++++ ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl | 7 +++++++ ext/SymbolicIndexingInterfaceSymbolicsExt.jl | 7 +++++++ 3 files changed, 24 insertions(+) create mode 100644 ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl create mode 100644 ext/SymbolicIndexingInterfaceSymbolicsExt.jl diff --git a/Project.toml b/Project.toml index 82ee6fd..f817e87 100644 --- a/Project.toml +++ b/Project.toml @@ -3,8 +3,18 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" authors = ["Aayush Sabharwal and contributors"] version = "0.3.0" +[weakdeps] +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" + +[extensions] +SymbolicIndexingInterfaceSymbolicsExt = "Symbolics" +SymbolicIndexingInterfaceSymbolicUtilsExt = "SymbolicUtils" + [compat] julia = "1" +Symbolics = "5.10" +SymbolicUtils = "1.4" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl b/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl new file mode 100644 index 0000000..ea70316 --- /dev/null +++ b/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl @@ -0,0 +1,7 @@ +module SymbolicIndexingInterfaceSymbolicUtilsExt + +using SymbolicIndexingInterface, SymbolicUtils + +SymbolicIndexingInterface.issymbolic(::Type{<:SymbolicUtils.BasicSymbolic}) = Symbolic() + +end diff --git a/ext/SymbolicIndexingInterfaceSymbolicsExt.jl b/ext/SymbolicIndexingInterfaceSymbolicsExt.jl new file mode 100644 index 0000000..dd767a9 --- /dev/null +++ b/ext/SymbolicIndexingInterfaceSymbolicsExt.jl @@ -0,0 +1,7 @@ +module SymbolicIndexingInterfaceSymbolicsExt + +using SymbolicIndexingInterface, Symbolics + +SymbolicIndexingInterface.issymbolic(::Type{<:Symbolics.Num}) = Symbolic() + +end From 423e6111aa96130312610d9cc784ed52d2e7087f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 21:38:20 +0530 Subject: [PATCH 08/35] refactor: split into multiple files --- Project.toml | 8 +-- src/SymbolicIndexingInterface.jl | 103 ++----------------------------- src/interface.jl | 65 +++++++++++++++++++ src/trait.jl | 30 +++++++++ 4 files changed, 104 insertions(+), 102 deletions(-) create mode 100644 src/interface.jl create mode 100644 src/trait.jl diff --git a/Project.toml b/Project.toml index f817e87..19f2bb9 100644 --- a/Project.toml +++ b/Project.toml @@ -4,17 +4,17 @@ authors = ["Aayush Sabharwal and contributors"] version = "0.3.0" [weakdeps] -Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] -SymbolicIndexingInterfaceSymbolicsExt = "Symbolics" SymbolicIndexingInterfaceSymbolicUtilsExt = "SymbolicUtils" +SymbolicIndexingInterfaceSymbolicsExt = "Symbolics" [compat] -julia = "1" -Symbolics = "5.10" SymbolicUtils = "1.4" +Symbolics = "5.10" +julia = "1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 32d4bde..f2f2013 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -1,104 +1,11 @@ module SymbolicIndexingInterface -export Symbolic, NotSymbolic, issymbolic, isvariable, variableindex, isparameter, - parameterindex, isindependent_variable, isobserved, observed, istimedependent, - constant_structure +export Symbolic, NotSymbolic +include("trait.jl") -abstract type IsSymbolicTrait end +export issymbolic, isvariable, variableindex, isparameter, parameterindex, + isindependent_variable, isobserved, observed, istimedependent, constant_structure +include("interface.jl") -""" - struct Symbolic <: IsSymbolicTrait end - -Trait indicating a type is symbolic. - -See also: [`NotSymbolic`](@ref), [`issymbolic`](@ref) -""" -struct Symbolic <: IsSymbolicTrait end - -""" - struct NotSymbolic <: IsSymbolicTrait end - -Trait indicating a type is not symbolic. - -See also: [`Symbolic`](@ref), [`issymbolic`](@ref) -""" -struct NotSymbolic <: IsSymbolicTrait end - -""" - issymbolic(x) = issymbolic(typeof(x)) - issymbolic(::Type) - -Check whether a type implements the [`Symbolic`](@ref) trait or not. Default to -[`NotSymbolic`](@ref) for all types except `Symbol`. -""" -issymbolic(x) = issymbolic(typeof(x)) -issymbolic(::Type) = NotSymbolic() -issymbolic(::Type{Symbol}) = Symbolic() - -""" - isvariable(sys, sym) - -Check whether the given `sym` is a variable in `sys`. -""" -function isvariable end - -""" - variableindex(sys, sym) - -Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. -""" -function variableindex end - -""" - isparameter(sys, sym) - -Check whether the given `sym` is a parameter in `sys`. -""" -function isparameter end - -""" - parameterindex(sys, sym) - -Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise. -""" -function parameterindex end - -""" - isindependent_variable(sys, sym) - -Check whether the given `sym` is an independent variable in `sys`. -""" -function isindependent_variable end - -""" - isobserved(sys, sym) - -Check whether the given `sym` is an observed value in `sys`. -""" -function isobserved end - -""" - observed(sys, sym) - -Return the observed function of the given `sym` in `sys`. The returned function should -accept a timeseries if `sys` has an independent variable, and return the observed -values for the given `sym`. -""" -function observed end - -""" - istimedependent(sys) - -Check if `sys` has time as (one of) its independent variables. -""" -function istimedependent end - -""" - constant_structure(sys) - -Check if `sys` has a constant structure. Constant structure systems do not change the -number of variables or parameters over time. -""" -function constant_structure end end diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..c2d8796 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,65 @@ +""" + isvariable(sys, sym) + +Check whether the given `sym` is a variable in `sys`. +""" +function isvariable end + +""" + variableindex(sys, sym) + +Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. +""" +function variableindex end + +""" + isparameter(sys, sym) + +Check whether the given `sym` is a parameter in `sys`. +""" +function isparameter end + +""" + parameterindex(sys, sym) + +Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise. +""" +function parameterindex end + +""" + isindependent_variable(sys, sym) + +Check whether the given `sym` is an independent variable in `sys`. +""" +function isindependent_variable end + +""" + isobserved(sys, sym) + +Check whether the given `sym` is an observed value in `sys`. +""" +function isobserved end + +""" + observed(sys, sym) + +Return the observed function of the given `sym` in `sys`. The returned function should +accept a timeseries if `sys` has an independent variable, and return the observed +values for the given `sym`. +""" +function observed end + +""" + istimedependent(sys) + +Check if `sys` has time as (one of) its independent variables. +""" +function istimedependent end + +""" + constant_structure(sys) + +Check if `sys` has a constant structure. Constant structure systems do not change the +number of variables or parameters over time. +""" +function constant_structure end diff --git a/src/trait.jl b/src/trait.jl new file mode 100644 index 0000000..14936f3 --- /dev/null +++ b/src/trait.jl @@ -0,0 +1,30 @@ +abstract type IsSymbolicTrait end + +""" + struct Symbolic <: IsSymbolicTrait end + +Trait indicating a type is symbolic. + +See also: [`NotSymbolic`](@ref), [`issymbolic`](@ref) +""" +struct Symbolic <: IsSymbolicTrait end + +""" + struct NotSymbolic <: IsSymbolicTrait end + +Trait indicating a type is not symbolic. + +See also: [`Symbolic`](@ref), [`issymbolic`](@ref) +""" +struct NotSymbolic <: IsSymbolicTrait end + +""" + issymbolic(x) = issymbolic(typeof(x)) + issymbolic(::Type) + +Check whether a type implements the [`Symbolic`](@ref) trait or not. Default to +[`NotSymbolic`](@ref) for all types except `Symbol`. +""" +issymbolic(x) = issymbolic(typeof(x)) +issymbolic(::Type) = NotSymbolic() +issymbolic(::Type{Symbol}) = Symbolic() From 1fe9d57068e61e30dbb999f9979acae1ffcbab28 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 21:39:15 +0530 Subject: [PATCH 09/35] feat: add pre-1.9 compat using Requires --- Project.toml | 4 ++++ src/SymbolicIndexingInterface.jl | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/Project.toml b/Project.toml index 19f2bb9..36da873 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,9 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" authors = ["Aayush Sabharwal and contributors"] version = "0.3.0" +[deps] +Requires = "ae029012-a4dd-5104-9daa-d747884805df" + [weakdeps] SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -12,6 +15,7 @@ SymbolicIndexingInterfaceSymbolicUtilsExt = "SymbolicUtils" SymbolicIndexingInterfaceSymbolicsExt = "Symbolics" [compat] +Requires = "1.0" SymbolicUtils = "1.4" Symbolics = "5.10" julia = "1" diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index f2f2013..cbaf9d3 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -7,5 +7,12 @@ export issymbolic, isvariable, variableindex, isparameter, parameterindex, isindependent_variable, isobserved, observed, istimedependent, constant_structure include("interface.jl") +@static if !isdefined(Base, :get_extension) + using Requires + function __init__() + @require Symbolics="0c5d862f-8b57-4792-8d23-62f2024744c7" include("../ext/SymbolicIndexingInterfaceSymbolicsExt.jl") + @require SymbolicUtils="d1185830-fcd6-423d-90d6-eec64667417b" include("../ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl") + end +end end From 95e6c1b6021433f4edd7d7dfc0d364c967b96097 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Oct 2023 22:21:10 +0530 Subject: [PATCH 10/35] fix: conditional import in extensions --- ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl | 7 ++++++- ext/SymbolicIndexingInterfaceSymbolicsExt.jl | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl b/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl index ea70316..df48aa3 100644 --- a/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl +++ b/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl @@ -1,6 +1,11 @@ module SymbolicIndexingInterfaceSymbolicUtilsExt -using SymbolicIndexingInterface, SymbolicUtils +using SymbolicIndexingInterface +@static if isdefined(Base, :get_extension) + using SymbolicUtils +else + using ..SymbolicUtils +end SymbolicIndexingInterface.issymbolic(::Type{<:SymbolicUtils.BasicSymbolic}) = Symbolic() diff --git a/ext/SymbolicIndexingInterfaceSymbolicsExt.jl b/ext/SymbolicIndexingInterfaceSymbolicsExt.jl index dd767a9..f8acbbf 100644 --- a/ext/SymbolicIndexingInterfaceSymbolicsExt.jl +++ b/ext/SymbolicIndexingInterfaceSymbolicsExt.jl @@ -1,6 +1,12 @@ module SymbolicIndexingInterfaceSymbolicsExt -using SymbolicIndexingInterface, Symbolics +using SymbolicIndexingInterface + +@static if isdefined(Base, :get_extension) + using Symbolics +else + using ..Symbolics +end SymbolicIndexingInterface.issymbolic(::Type{<:Symbolics.Num}) = Symbolic() From 9bea16466126a7c57c8305c11597aea8beb190b7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 23 Oct 2023 11:51:36 +0530 Subject: [PATCH 11/35] feat: support nonstatic states, normalize function names - add has_static_variable function - make variable names more consistent --- src/SymbolicIndexingInterface.jl | 6 +- src/interface.jl | 50 ++++++++----- test/example_test.jl | 119 +++++++++++++++++++------------ 3 files changed, 110 insertions(+), 65 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index cbaf9d3..7d09703 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -1,10 +1,10 @@ module SymbolicIndexingInterface -export Symbolic, NotSymbolic +export Symbolic, NotSymbolic, issymbolic include("trait.jl") -export issymbolic, isvariable, variableindex, isparameter, parameterindex, - isindependent_variable, isobserved, observed, istimedependent, constant_structure +export is_variable, has_static_variable, variable_index, is_parameter, parameter_index, + is_independent_variable, is_observed, observed, is_time_dependent, constant_structure include("interface.jl") @static if !isdefined(Base, :get_extension) diff --git a/src/interface.jl b/src/interface.jl index c2d8796..016bdf6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,60 +1,74 @@ """ - isvariable(sys, sym) + is_variable(sys, sym) Check whether the given `sym` is a variable in `sys`. """ -function isvariable end +function is_variable end """ - variableindex(sys, sym) + has_static_variable(sys) -Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. +Check whether the variables in `sys` are stable across time. """ -function variableindex end +function has_static_variable end """ - isparameter(sys, sym) + variable_index(sys, sym, [i]) + +Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. If +[`has_static_variable`](@ref) is `false`, this accepts the current timestep as an +additional parameter `i`. +""" +function variable_index end + +""" + is_parameter(sys, sym) Check whether the given `sym` is a parameter in `sys`. """ -function isparameter end +function is_parameter end """ - parameterindex(sys, sym) + parameter_index(sys, sym) Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise. """ -function parameterindex end +function parameter_index end """ - isindependent_variable(sys, sym) + is_independent_variable(sys, sym) Check whether the given `sym` is an independent variable in `sys`. """ -function isindependent_variable end +function is_independent_variable end """ - isobserved(sys, sym) + is_observed(sys, sym) Check whether the given `sym` is an observed value in `sys`. """ -function isobserved end +function is_observed end """ - observed(sys, sym) + observed(sys, sym, [symbolic_states]) Return the observed function of the given `sym` in `sys`. The returned function should -accept a timeseries if `sys` has an independent variable, and return the observed -values for the given `sym`. +have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and +parameter vector. If `istimedependent(sys) == true`, the function should accept +the current time `t` as its third parameter. If `has_static_variable(sys) == false` then +`observed` must accept a third parameter `symbolic_states` indicating the order of symbolic +variables in `u`. + +See also: [`is_time_dependent`](@ref), [`has_static_variable`](@ref) """ function observed end """ - istimedependent(sys) + is_time_dependent(sys) Check if `sys` has time as (one of) its independent variables. """ -function istimedependent end +function is_time_dependent end """ constant_structure(sys) diff --git a/test/example_test.jl b/test/example_test.jl index 2f23e5d..7c0d20a 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -1,70 +1,101 @@ struct SymbolCache + static::Bool vars::Vector{Symbol} params::Vector{Symbol} indepvar::Union{Symbol, Nothing} end -SymbolicIndexingInterface.isvariable(sys::SymbolCache, sym) = sym in sys.vars -function SymbolicIndexingInterface.variableindex(sys::SymbolCache, sym) +SymbolicIndexingInterface.is_variable(sys::SymbolCache, sym) = sym in sys.vars +function SymbolicIndexingInterface.variable_index(sys::SymbolCache, sym, t = nothing) + if !has_static_variable(sys) && t === nothing + error("timestep must be present") + end findfirst(isequal(sym), sys.vars) end -SymbolicIndexingInterface.isparameter(sys::SymbolCache, sym) = sym in sys.params -function SymbolicIndexingInterface.parameterindex(sys::SymbolCache, sym) +SymbolicIndexingInterface.is_parameter(sys::SymbolCache, sym) = sym in sys.params +function SymbolicIndexingInterface.parameter_index(sys::SymbolCache, sym) findfirst(isequal(sym), sys.params) end -function SymbolicIndexingInterface.isindependent_variable(sys::SymbolCache, sym) +function SymbolicIndexingInterface.is_independent_variable(sys::SymbolCache, sym) sys.indepvar !== nothing && isequal(sym, sys.indepvar) end -function SymbolicIndexingInterface.isobserved(sys::SymbolCache, sym) - isvariable(sys, sym) || isparameter(sys, sym) || isindependent_variable(sys, sym) +function SymbolicIndexingInterface.is_observed(sys::SymbolCache, sym) + is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym) end -function SymbolicIndexingInterface.observed(sys::SymbolCache, sym) - idx = variableindex(sys, sym) +function SymbolicIndexingInterface.observed(sys::SymbolCache, + sym, + symbolic_states = nothing) + if !has_static_variable(sys) && symbolic_states === nothing + error("Symbolic states must be present") + end + if has_static_variable(sys) + symbolic_states = sys.vars + end + idx = findfirst(isequal(sym), symbolic_states) if idx !== nothing - return istimedependent(sys) ? (t) -> [idx * i for i in t] : - () -> [idx * i for i in 1:5] + return is_time_dependent(sys) ? (u, p, t) -> [u[idx] * i for i in t] : + (u, p) -> [u[idx] * i for i in 1:5] end - idx = parameterindex(sys, sym) + idx = parameter_index(sys, sym) if idx !== nothing - return istimedependent(sys) ? (t) -> idx : () -> idx + return is_time_dependent(sys) ? (u, p, t) -> p[idx] : (u, p) -> p[idx] end - if isindependent_variable(sys, sym) - return istimedependent(sys) ? (t) -> t : () -> 1:5 + if is_independent_variable(sys, sym) + return is_time_dependent(sys) ? (u, p, t) -> t : (u, p) -> 1:5 end end -SymbolicIndexingInterface.istimedependent(sys::SymbolCache) = isequal(sys.indepvar, :t) +SymbolicIndexingInterface.is_time_dependent(sys::SymbolCache) = isequal(sys.indepvar, :t) SymbolicIndexingInterface.constant_structure(sys::SymbolCache) = true +SymbolicIndexingInterface.has_static_variable(sys::SymbolCache) = sys.static -sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], :t) -@test all(isvariable.((sys,), [:x, :y, :z])) -@test all(.!isvariable.((sys,), [:a, :b, :c, :t, :p, :q, :r])) -@test all(variableindex.((sys,), [:x, :z, :y]) .== [1, 3, 2]) -@test all(variableindex.((sys,), [:a, :b, :c, :t, :p, :q, :r]) .=== nothing) -@test all(isparameter.((sys,), [:a, :b, :c])) -@test all(.!isparameter.((sys,), [:x, :y, :z, :t, :p, :q, :r])) -@test all(parameterindex.((sys,), [:c, :a, :b]) .== [3, 1, 2]) -@test all(parameterindex.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing) -@test isindependent_variable(sys, :t) -@test all(.!isindependent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r])) -@test all(isobserved.((sys,), [:x, :y, :z, :a, :b, :c, :t])) -@test all(observed(sys, :x)(1:4) .== [1, 2, 3, 4]) -@test all(observed(sys, :y)(1:4) .== [2, 4, 6, 8]) -@test all(observed(sys, :z)(1:4) .== [3, 6, 9, 12]) -@test observed(sys, :a)(1:4) == 1 -@test observed(sys, :b)(1:4) == 2 -@test observed(sys, :c)(1:4) == 3 -@test observed(sys, :t)(1:4) == 1:4 -@test istimedependent(sys) +@test all(is_variable.((sys,), [:x, :y, :z])) +@test all(.!is_variable.((sys,), [:a, :b, :c, :t, :p, :q, :r])) +@test all(variable_index.((sys,), [:x, :z, :y]) .== [1, 3, 2]) +@test all(variable_index.((sys,), [:a, :b, :c, :t, :p, :q, :r]) .=== nothing) +@test all(is_parameter.((sys,), [:a, :b, :c])) +@test all(.!is_parameter.((sys,), [:x, :y, :z, :t, :p, :q, :r])) +@test all(parameter_index.((sys,), [:c, :a, :b]) .== [3, 1, 2]) +@test all(parameter_index.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing) +@test is_independent_variable(sys, :t) +@test all(.!is_independent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r])) +@test all(is_observed.((sys,), [:x, :y, :z, :a, :b, :c, :t])) +@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [1, 2, 3, 4]) +@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [2, 4, 6, 8]) +@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [3, 6, 9, 12]) +@test observed(sys, :a)(1:3, 4:6, 1:4) == 4 +@test observed(sys, :b)(1:3, 4:6, 1:4) == 5 +@test observed(sys, :c)(1:3, 4:6, 1:4) == 6 +@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4 +@test is_time_dependent(sys) @test constant_structure(sys) +@test has_static_variable(sys) -sys = SymbolCache([:x, :y, :z], [:a, :b, :c], nothing) +sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], nothing) -@test !istimedependent(sys) -@test all(observed(sys, :x)() .== [1, 2, 3, 4, 5]) -@test all(observed(sys, :y)() .== [2, 4, 6, 8, 10]) -@test all(observed(sys, :z)() .== [3, 6, 9, 12, 15]) -@test observed(sys, :a)() == 1 -@test observed(sys, :b)() == 2 -@test observed(sys, :c)() == 3 +@test !is_time_dependent(sys) +@test all(observed(sys, :x)(1:3, 4:6) .== [1, 2, 3, 4, 5]) +@test all(observed(sys, :y)(1:3, 4:6) .== [2, 4, 6, 8, 10]) +@test all(observed(sys, :z)(1:3, 4:6) .== [3, 6, 9, 12, 15]) +@test observed(sys, :a)(1:3, 4:6) == 4 +@test observed(sys, :b)(1:3, 4:6) == 5 +@test observed(sys, :c)(1:3, 4:6) == 6 @test constant_structure(sys) + +sys = SymbolCache(false, [:x, :y, :z], [:a, :b, :c], :t) +@test !has_static_variable(sys) +for variable in [:x, :y, :z, :a, :b, :c, :t] + @test_throws ErrorException variable_index(sys, variable) + @test_throws ErrorException observed(sys, variable) +end +@test all(variable_index.((sys,), [:z, :y, :x], 1) .== [3, 2, 1]) +@test all(variable_index.((sys,), [:a, :b, :c, :t], 1) .== nothing) +variable_order = [:x, :y, :z] +@test all(observed(sys, :x, variable_order)(1:3, 4:6, 1:4) .== [1, 2, 3, 4]) +@test all(observed(sys, :y, variable_order)(1:4, 4:6, 1:4) .== [2, 4, 6, 8]) +@test all(observed(sys, :z, variable_order)(1:4, 4:6, 1:4) .== [3, 6, 9, 12]) +@test observed(sys, :a, variable_order)(1:3, 4:6, 1:4) == 4 +@test observed(sys, :b, variable_order)(1:3, 4:6, 1:4) == 5 +@test observed(sys, :c, variable_order)(1:3, 4:6, 1:4) == 6 +@test observed(sys, :t, variable_order)(1:3, 4:6, 1:4) == 1:4 From 859e130805617a95bea808083f7d67f77be3f767 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 23 Oct 2023 12:02:28 +0530 Subject: [PATCH 12/35] test: add trait tests --- test/runtests.jl | 5 ++++- test/trait_test.jl | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 test/trait_test.jl diff --git a/test/runtests.jl b/test/runtests.jl index eb0fc78..7cc19d9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,9 @@ using SymbolicIndexingInterface using Test -@testset begin +@testset "Interface test" begin @time include("example_test.jl") end +@testset "Trait test" begin + @time include("trait_test.jl") +end diff --git a/test/trait_test.jl b/test/trait_test.jl new file mode 100644 index 0000000..9665509 --- /dev/null +++ b/test/trait_test.jl @@ -0,0 +1,14 @@ +using SymbolicUtils +using Symbolics +using SymbolicIndexingInterface +using Test + +@test all(issymbolic.([Int, Float64, String, Bool, UInt, Complex{Float64}]) .== + (NotSymbolic(),)) +@test all(issymbolic.([Symbol, SymbolicUtils.BasicSymbolic, Symbolics.Num]) .== + (Symbolic(),)) +@variables x +@test issymbolic(x) == Symbolic() +@variables y[1:3] +@test issymbolic(y) == NotSymbolic() +@test all(issymbolic.(collect(y)) .== (Symbolic(),)) From 51ad41ffaf47c1da702c982e8e7d0e9bbe6cff88 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 23 Oct 2023 12:07:26 +0530 Subject: [PATCH 13/35] fixup! test: add trait tests --- Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 36da873..a139bb4 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,9 @@ Symbolics = "5.10" julia = "1" [extras] +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["SymbolicUtils", "Symbolics","Test"] From 66e59f28dec5f02285105418a1e049fb80a6146f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 23 Oct 2023 12:37:24 +0530 Subject: [PATCH 14/35] fix: move Requires.jl import --- src/SymbolicIndexingInterface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 7d09703..75dc963 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -1,5 +1,7 @@ module SymbolicIndexingInterface +using Requires + export Symbolic, NotSymbolic, issymbolic include("trait.jl") @@ -8,7 +10,6 @@ export is_variable, has_static_variable, variable_index, is_parameter, parameter include("interface.jl") @static if !isdefined(Base, :get_extension) - using Requires function __init__() @require Symbolics="0c5d862f-8b57-4792-8d23-62f2024744c7" include("../ext/SymbolicIndexingInterfaceSymbolicsExt.jl") @require SymbolicUtils="d1185830-fcd6-423d-90d6-eec64667417b" include("../ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl") From 52fc726e26c52f04b9dc89b9a01cfa8c28022b1c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Oct 2023 11:27:17 +0530 Subject: [PATCH 15/35] feat: rework trait --- ...mbolicIndexingInterfaceSymbolicUtilsExt.jl | 4 +- ext/SymbolicIndexingInterfaceSymbolicsExt.jl | 3 +- src/SymbolicIndexingInterface.jl | 2 +- src/trait.jl | 41 ++++++++++++------- test/trait_test.jl | 13 +++--- 5 files changed, 39 insertions(+), 24 deletions(-) diff --git a/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl b/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl index df48aa3..0b202b9 100644 --- a/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl +++ b/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl @@ -7,6 +7,8 @@ else using ..SymbolicUtils end -SymbolicIndexingInterface.issymbolic(::Type{<:SymbolicUtils.BasicSymbolic}) = Symbolic() +function SymbolicIndexingInterface.symbolic_type(::Type{<:SymbolicUtils.BasicSymbolic}) + ScalarSymbolic() +end end diff --git a/ext/SymbolicIndexingInterfaceSymbolicsExt.jl b/ext/SymbolicIndexingInterfaceSymbolicsExt.jl index f8acbbf..662d53f 100644 --- a/ext/SymbolicIndexingInterfaceSymbolicsExt.jl +++ b/ext/SymbolicIndexingInterfaceSymbolicsExt.jl @@ -8,6 +8,7 @@ else using ..Symbolics end -SymbolicIndexingInterface.issymbolic(::Type{<:Symbolics.Num}) = Symbolic() +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() end diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 75dc963..51fa896 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -2,7 +2,7 @@ module SymbolicIndexingInterface using Requires -export Symbolic, NotSymbolic, issymbolic +export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type include("trait.jl") export is_variable, has_static_variable, variable_index, is_parameter, parameter_index, diff --git a/src/trait.jl b/src/trait.jl index 14936f3..8b2a95a 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -1,30 +1,41 @@ -abstract type IsSymbolicTrait end +abstract type SymbolicTypeTrait end """ - struct Symbolic <: IsSymbolicTrait end + struct ScalarSymbolic <: SymbolicTypeTrait end -Trait indicating a type is symbolic. +Trait indicating a type is a scalar symbolic variable. -See also: [`NotSymbolic`](@ref), [`issymbolic`](@ref) +See also: [`ArraySymbolic`](@ref), [`NotSymbolic`](@ref), [`symbolic_type`](@ref) """ -struct Symbolic <: IsSymbolicTrait end +struct ScalarSymbolic <: SymbolicTypeTrait end """ - struct NotSymbolic <: IsSymbolicTrait end + struct ArraySymbolic <: SymbolicTypeTrait end + +Trait indicating type is a symbolic array or an array of scalar symbolic variables. + +See also: [`ScalarSymbolic`](@ref), [`NotSymbolic`](@ref), [`symbolic_type`](@ref) +""" +struct ArraySymbolic <: SymbolicTypeTrait end + +""" + struct NotSymbolic <: SymbolicTypeTrait end Trait indicating a type is not symbolic. -See also: [`Symbolic`](@ref), [`issymbolic`](@ref) +See also: [`ScalarSymbolic`](@ref), [`ArraySymbolic`](@ref), [`symbolic_type`](@ref) """ -struct NotSymbolic <: IsSymbolicTrait end +struct NotSymbolic <: SymbolicTypeTrait end """ - issymbolic(x) = issymbolic(typeof(x)) - issymbolic(::Type) + symbolic_type(x) = symbolic_type(typeof(x)) + symbolic_type(::Type) + +Get the symbolic type trait of a type. Default to [`NotSymbolic`](@ref) for all types +except `Symbol`. -Check whether a type implements the [`Symbolic`](@ref) trait or not. Default to -[`NotSymbolic`](@ref) for all types except `Symbol`. +See also: [`ScalarSymbolic`](@ref), [`ArraySymbolic`](@ref), [`NotSymbolic`](@ref) """ -issymbolic(x) = issymbolic(typeof(x)) -issymbolic(::Type) = NotSymbolic() -issymbolic(::Type{Symbol}) = Symbolic() +symbolic_type(x) = symbolic_type(typeof(x)) +symbolic_type(::Type) = NotSymbolic() +symbolic_type(::Type{Symbol}) = ScalarSymbolic() diff --git a/test/trait_test.jl b/test/trait_test.jl index 9665509..9527997 100644 --- a/test/trait_test.jl +++ b/test/trait_test.jl @@ -3,12 +3,13 @@ using Symbolics using SymbolicIndexingInterface using Test -@test all(issymbolic.([Int, Float64, String, Bool, UInt, Complex{Float64}]) .== +@test all(symbolic_type.([Int, Float64, String, Bool, UInt, Complex{Float64}]) .== (NotSymbolic(),)) -@test all(issymbolic.([Symbol, SymbolicUtils.BasicSymbolic, Symbolics.Num]) .== - (Symbolic(),)) +@test all(symbolic_type.([Symbol, SymbolicUtils.BasicSymbolic, Symbolics.Num]) .== + (ScalarSymbolic(),)) +@test symbolic_type(Symbolics.Arr) == ArraySymbolic() @variables x -@test issymbolic(x) == Symbolic() +@test symbolic_type(x) == ScalarSymbolic() @variables y[1:3] -@test issymbolic(y) == NotSymbolic() -@test all(issymbolic.(collect(y)) .== (Symbolic(),)) +@test symbolic_type(y) == ArraySymbolic() +@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),)) From 988573158c28462d13e6a004797cb010243fd435 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Oct 2023 11:27:36 +0530 Subject: [PATCH 16/35] feat: add current_state, refactor observed, update tests --- src/SymbolicIndexingInterface.jl | 3 ++- src/interface.jl | 13 ++++++---- test/example_test.jl | 41 +++++++++++++------------------- 3 files changed, 28 insertions(+), 29 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 51fa896..2ad8a3e 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -6,7 +6,8 @@ export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type include("trait.jl") export is_variable, has_static_variable, variable_index, is_parameter, parameter_index, - is_independent_variable, is_observed, observed, is_time_dependent, constant_structure + is_independent_variable, current_state, is_observed, observed, is_time_dependent, + constant_structure include("interface.jl") @static if !isdefined(Base, :get_extension) diff --git a/src/interface.jl b/src/interface.jl index 016bdf6..f07de3b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -42,6 +42,13 @@ Check whether the given `sym` is an independent variable in `sys`. """ function is_independent_variable end +""" + current_state(sys, i) + +Return a vector of symbols representing the state variables at timestep `i`. +""" +function current_state end + """ is_observed(sys, sym) @@ -50,14 +57,12 @@ Check whether the given `sym` is an observed value in `sys`. function is_observed end """ - observed(sys, sym, [symbolic_states]) + observed(sys, sym) Return the observed function of the given `sym` in `sys`. The returned function should have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and parameter vector. If `istimedependent(sys) == true`, the function should accept -the current time `t` as its third parameter. If `has_static_variable(sys) == false` then -`observed` must accept a third parameter `symbolic_states` indicating the order of symbolic -variables in `u`. +the current time `t` as its third parameter. See also: [`is_time_dependent`](@ref), [`has_static_variable`](@ref) """ diff --git a/test/example_test.jl b/test/example_test.jl index 7c0d20a..e1e8a61 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -10,7 +10,10 @@ function SymbolicIndexingInterface.variable_index(sys::SymbolCache, sym, t = not if !has_static_variable(sys) && t === nothing error("timestep must be present") end - findfirst(isequal(sym), sys.vars) + findfirst(isequal(sym), current_state(sys, t)) +end +function SymbolicIndexingInterface.current_state(sys::SymbolCache, i) + return has_static_variable(sys) ? sys.vars : circshift(sys.vars, i) end SymbolicIndexingInterface.is_parameter(sys::SymbolCache, sym) = sym in sys.params function SymbolicIndexingInterface.parameter_index(sys::SymbolCache, sym) @@ -22,19 +25,11 @@ end function SymbolicIndexingInterface.is_observed(sys::SymbolCache, sym) is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym) end -function SymbolicIndexingInterface.observed(sys::SymbolCache, - sym, - symbolic_states = nothing) - if !has_static_variable(sys) && symbolic_states === nothing - error("Symbolic states must be present") - end - if has_static_variable(sys) - symbolic_states = sys.vars - end - idx = findfirst(isequal(sym), symbolic_states) - if idx !== nothing - return is_time_dependent(sys) ? (u, p, t) -> [u[idx] * i for i in t] : - (u, p) -> [u[idx] * i for i in 1:5] +function SymbolicIndexingInterface.observed(sys::SymbolCache, sym) + if is_variable(sys, sym) + return is_time_dependent(sys) ? + (u, p, t) -> [u[variable_index(sys, sym, i)] * i for i in t] : + (u, p) -> [u[variable_index(sys, sym)] * i for i in 1:5] end idx = parameter_index(sys, sym) if idx !== nothing @@ -87,15 +82,13 @@ sys = SymbolCache(false, [:x, :y, :z], [:a, :b, :c], :t) @test !has_static_variable(sys) for variable in [:x, :y, :z, :a, :b, :c, :t] @test_throws ErrorException variable_index(sys, variable) - @test_throws ErrorException observed(sys, variable) end -@test all(variable_index.((sys,), [:z, :y, :x], 1) .== [3, 2, 1]) +@test all(variable_index.((sys,), [:z, :y, :x], 1) .== [1, 3, 2]) @test all(variable_index.((sys,), [:a, :b, :c, :t], 1) .== nothing) -variable_order = [:x, :y, :z] -@test all(observed(sys, :x, variable_order)(1:3, 4:6, 1:4) .== [1, 2, 3, 4]) -@test all(observed(sys, :y, variable_order)(1:4, 4:6, 1:4) .== [2, 4, 6, 8]) -@test all(observed(sys, :z, variable_order)(1:4, 4:6, 1:4) .== [3, 6, 9, 12]) -@test observed(sys, :a, variable_order)(1:3, 4:6, 1:4) == 4 -@test observed(sys, :b, variable_order)(1:3, 4:6, 1:4) == 5 -@test observed(sys, :c, variable_order)(1:3, 4:6, 1:4) == 6 -@test observed(sys, :t, variable_order)(1:3, 4:6, 1:4) == 1:4 +@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [2, 6, 3, 8]) +@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [3, 2, 6, 12]) +@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [1, 4, 9, 4]) +@test observed(sys, :a)(1:3, 4:6, 1:4) == 4 +@test observed(sys, :b)(1:3, 4:6, 1:4) == 5 +@test observed(sys, :c)(1:3, 4:6, 1:4) == 6 +@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4 From 15a987b4f10f46aac95c8e54ecf31ccdea90660e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 31 Oct 2023 15:03:02 +0530 Subject: [PATCH 17/35] feat: update interface, tests --- src/interface.jl | 8 +++++--- test/example_test.jl | 49 ++++++++++++++++++++++++-------------------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index f07de3b..7609b65 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -16,7 +16,7 @@ function has_static_variable end variable_index(sys, sym, [i]) Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. If -[`has_static_variable`](@ref) is `false`, this accepts the current timestep as an +[`has_static_variable`](@ref) is `false`, this accepts the current time index as an additional parameter `i`. """ function variable_index end @@ -57,12 +57,14 @@ Check whether the given `sym` is an observed value in `sys`. function is_observed end """ - observed(sys, sym) + observed(sys, sym, [states]) Return the observed function of the given `sym` in `sys`. The returned function should have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and parameter vector. If `istimedependent(sys) == true`, the function should accept -the current time `t` as its third parameter. +the current time `t` as its third parameter. If `constant_structure(sys) == false`, +accept a third parameter which can either be a vector of symbols indicating the order +of states or a time index which identifies the order of states. See also: [`is_time_dependent`](@ref), [`has_static_variable`](@ref) """ diff --git a/test/example_test.jl b/test/example_test.jl index e1e8a61..2adccf0 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -8,7 +8,7 @@ end SymbolicIndexingInterface.is_variable(sys::SymbolCache, sym) = sym in sys.vars function SymbolicIndexingInterface.variable_index(sys::SymbolCache, sym, t = nothing) if !has_static_variable(sys) && t === nothing - error("timestep must be present") + error("time index must be present") end findfirst(isequal(sym), current_state(sys, t)) end @@ -25,18 +25,22 @@ end function SymbolicIndexingInterface.is_observed(sys::SymbolCache, sym) is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym) end -function SymbolicIndexingInterface.observed(sys::SymbolCache, sym) +function SymbolicIndexingInterface.observed(sys::SymbolCache, sym, states = nothing) + if !has_static_variable(sys) && states === nothing + error("States required") + end + states = states isa Vector ? states : current_state(sys, states) if is_variable(sys, sym) return is_time_dependent(sys) ? - (u, p, t) -> [u[variable_index(sys, sym, i)] * i for i in t] : - (u, p) -> [u[variable_index(sys, sym)] * i for i in 1:5] + (u, p, t) -> u[findfirst(isequal(sym), states)] : + (u, p) -> u[findfirst(isequal(sym), states)] end idx = parameter_index(sys, sym) if idx !== nothing return is_time_dependent(sys) ? (u, p, t) -> p[idx] : (u, p) -> p[idx] end if is_independent_variable(sys, sym) - return is_time_dependent(sys) ? (u, p, t) -> t : (u, p) -> 1:5 + return is_time_dependent(sys) ? (u, p, t) -> t : (u, p) -> 1 end end SymbolicIndexingInterface.is_time_dependent(sys::SymbolCache) = isequal(sys.indepvar, :t) @@ -56,13 +60,13 @@ sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], :t) @test is_independent_variable(sys, :t) @test all(.!is_independent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r])) @test all(is_observed.((sys,), [:x, :y, :z, :a, :b, :c, :t])) -@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [1, 2, 3, 4]) -@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [2, 4, 6, 8]) -@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [3, 6, 9, 12]) -@test observed(sys, :a)(1:3, 4:6, 1:4) == 4 -@test observed(sys, :b)(1:3, 4:6, 1:4) == 5 -@test observed(sys, :c)(1:3, 4:6, 1:4) == 6 -@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4 +@test all(observed(sys, :x)(1:3, 4:6, 1.5) .== 1) +@test all(observed(sys, :y)(1:3, 4:6, 1.5) .== 2) +@test all(observed(sys, :z)(1:3, 4:6, 1.5) .== 3) +@test observed(sys, :a)(1:3, 4:6, 1.5) == 4 +@test observed(sys, :b)(1:3, 4:6, 1.5) == 5 +@test observed(sys, :c)(1:3, 4:6, 1.5) == 6 +@test observed(sys, :t)(1:3, 4:6, 1.5) == 1.5 @test is_time_dependent(sys) @test constant_structure(sys) @test has_static_variable(sys) @@ -70,9 +74,9 @@ sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], :t) sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], nothing) @test !is_time_dependent(sys) -@test all(observed(sys, :x)(1:3, 4:6) .== [1, 2, 3, 4, 5]) -@test all(observed(sys, :y)(1:3, 4:6) .== [2, 4, 6, 8, 10]) -@test all(observed(sys, :z)(1:3, 4:6) .== [3, 6, 9, 12, 15]) +@test all(observed(sys, :x)(1.0:3.0, 4:6) .== 1.0) +@test all(observed(sys, :y)(1.0:3.0, 4:6) .== 2.0) +@test all(observed(sys, :z)(1.0:3.0, 4:6) .== 3.0) @test observed(sys, :a)(1:3, 4:6) == 4 @test observed(sys, :b)(1:3, 4:6) == 5 @test observed(sys, :c)(1:3, 4:6) == 6 @@ -82,13 +86,14 @@ sys = SymbolCache(false, [:x, :y, :z], [:a, :b, :c], :t) @test !has_static_variable(sys) for variable in [:x, :y, :z, :a, :b, :c, :t] @test_throws ErrorException variable_index(sys, variable) + @test_throws ErrorException observed(sys, variable) end @test all(variable_index.((sys,), [:z, :y, :x], 1) .== [1, 3, 2]) @test all(variable_index.((sys,), [:a, :b, :c, :t], 1) .== nothing) -@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [2, 6, 3, 8]) -@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [3, 2, 6, 12]) -@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [1, 4, 9, 4]) -@test observed(sys, :a)(1:3, 4:6, 1:4) == 4 -@test observed(sys, :b)(1:3, 4:6, 1:4) == 5 -@test observed(sys, :c)(1:3, 4:6, 1:4) == 6 -@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4 +@test all(observed(sys, :x, 2)(1:3, 4:6, 1.5) .== 3) +@test all(observed(sys, :y, 2)(1:3, 4:6, 1.5) .== 1) +@test all(observed(sys, :z, 2)(1:3, 4:6, 1.5) .== 2) +@test observed(sys, :a, 2)(1:3, 4:6, 1.5) == 4 +@test observed(sys, :b, 2)(1:3, 4:6, 1.5) == 5 +@test observed(sys, :c, 2)(1:3, 4:6, 1.5) == 6 +@test observed(sys, :t, 2)(1:3, 4:6, 1.5) == 1.5 From c1821c91e9520557df605cf5988f40f8f849af97 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 3 Nov 2023 17:53:20 +0530 Subject: [PATCH 18/35] feat: add SymbolCache --- src/SymbolicIndexingInterface.jl | 3 +++ src/symbol_cache.jl | 29 +++++++++++++++++++++++++++++ test/example_test.jl | 30 +++++++++++++++--------------- test/runtests.jl | 3 +++ test/symbol_cache_test.jl | 19 +++++++++++++++++++ 5 files changed, 69 insertions(+), 15 deletions(-) create mode 100644 src/symbol_cache.jl create mode 100644 test/symbol_cache_test.jl diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 2ad8a3e..e93a158 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -10,6 +10,9 @@ export is_variable, has_static_variable, variable_index, is_parameter, parameter constant_structure include("interface.jl") +export SymbolCache +include("symbol_cache.jl") + @static if !isdefined(Base, :get_extension) function __init__() @require Symbolics="0c5d862f-8b57-4792-8d23-62f2024744c7" include("../ext/SymbolicIndexingInterfaceSymbolicsExt.jl") diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl new file mode 100644 index 0000000..a2c1d4b --- /dev/null +++ b/src/symbol_cache.jl @@ -0,0 +1,29 @@ +""" + struct SymbolCache{V,P,I} + function SymbolCache(vars, [params, [indepvars]]) + +A struct implementing the symbolic indexing interface for the trivial case +of having a vector of variables, parameters and independent variables. This +struct does not implement `observed`, and `is_observed` returns `false` for +all input symbols. It is considered to be time dependent if it contains +at least one independent variable. +""" +struct SymbolCache{V, P, I} + variables::Vector{V} + parameters::Vector{P} + independent_variables::Vector{I} +end + +function SymbolCache(vars::Vector{V}, params = [], indepvars = []) where {V} + return SymbolCache{V, eltype(params), eltype(indepvars)}(vars, params, indepvars) +end + +is_variable(sc::SymbolCache, sym) = sym in sc.variables +variable_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.variables) +is_parameter(sc::SymbolCache, sym) = sym in sc.parameters +parameter_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.parameters) +is_independent_variable(sc::SymbolCache, sym) = sym in sc.independent_variables +current_state(sc::SymbolCache) = sc.variables +is_observed(sc::SymbolCache, sym) = false +is_time_dependent(sc::SymbolCache) = !isempty(sc.independent_variables) +constant_structure(::SymbolCache) = true diff --git a/test/example_test.jl b/test/example_test.jl index 2adccf0..6dfc03f 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -1,31 +1,31 @@ -struct SymbolCache +struct SystemMockup static::Bool vars::Vector{Symbol} params::Vector{Symbol} indepvar::Union{Symbol, Nothing} end -SymbolicIndexingInterface.is_variable(sys::SymbolCache, sym) = sym in sys.vars -function SymbolicIndexingInterface.variable_index(sys::SymbolCache, sym, t = nothing) +SymbolicIndexingInterface.is_variable(sys::SystemMockup, sym) = sym in sys.vars +function SymbolicIndexingInterface.variable_index(sys::SystemMockup, sym, t = nothing) if !has_static_variable(sys) && t === nothing error("time index must be present") end findfirst(isequal(sym), current_state(sys, t)) end -function SymbolicIndexingInterface.current_state(sys::SymbolCache, i) +function SymbolicIndexingInterface.current_state(sys::SystemMockup, i) return has_static_variable(sys) ? sys.vars : circshift(sys.vars, i) end -SymbolicIndexingInterface.is_parameter(sys::SymbolCache, sym) = sym in sys.params -function SymbolicIndexingInterface.parameter_index(sys::SymbolCache, sym) +SymbolicIndexingInterface.is_parameter(sys::SystemMockup, sym) = sym in sys.params +function SymbolicIndexingInterface.parameter_index(sys::SystemMockup, sym) findfirst(isequal(sym), sys.params) end -function SymbolicIndexingInterface.is_independent_variable(sys::SymbolCache, sym) +function SymbolicIndexingInterface.is_independent_variable(sys::SystemMockup, sym) sys.indepvar !== nothing && isequal(sym, sys.indepvar) end -function SymbolicIndexingInterface.is_observed(sys::SymbolCache, sym) +function SymbolicIndexingInterface.is_observed(sys::SystemMockup, sym) is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym) end -function SymbolicIndexingInterface.observed(sys::SymbolCache, sym, states = nothing) +function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = nothing) if !has_static_variable(sys) && states === nothing error("States required") end @@ -43,11 +43,11 @@ function SymbolicIndexingInterface.observed(sys::SymbolCache, sym, states = noth return is_time_dependent(sys) ? (u, p, t) -> t : (u, p) -> 1 end end -SymbolicIndexingInterface.is_time_dependent(sys::SymbolCache) = isequal(sys.indepvar, :t) -SymbolicIndexingInterface.constant_structure(sys::SymbolCache) = true -SymbolicIndexingInterface.has_static_variable(sys::SymbolCache) = sys.static +SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t) +SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = true +SymbolicIndexingInterface.has_static_variable(sys::SystemMockup) = sys.static -sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], :t) +sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test all(is_variable.((sys,), [:x, :y, :z])) @test all(.!is_variable.((sys,), [:a, :b, :c, :t, :p, :q, :r])) @@ -71,7 +71,7 @@ sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], :t) @test constant_structure(sys) @test has_static_variable(sys) -sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], nothing) +sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @test !is_time_dependent(sys) @test all(observed(sys, :x)(1.0:3.0, 4:6) .== 1.0) @@ -82,7 +82,7 @@ sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], nothing) @test observed(sys, :c)(1:3, 4:6) == 6 @test constant_structure(sys) -sys = SymbolCache(false, [:x, :y, :z], [:a, :b, :c], :t) +sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t) @test !has_static_variable(sys) for variable in [:x, :y, :z, :a, :b, :c, :t] @test_throws ErrorException variable_index(sys, variable) diff --git a/test/runtests.jl b/test/runtests.jl index 7cc19d9..f6aa6e3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,3 +7,6 @@ end @testset "Trait test" begin @time include("trait_test.jl") end +@testset "SymbolCache test" begin + @time include("symbol_cache_test.jl") +end diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl new file mode 100644 index 0000000..5da5dbc --- /dev/null +++ b/test/symbol_cache_test.jl @@ -0,0 +1,19 @@ +sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) + +@test all(is_variable.((sc,), [:x, :y, :z])) +@test all(.!is_variable.((sc,), [:a, :b, :t, :q])) +@test variable_index.((sc,), [:x, :y, :z, :a]) == [1, 2, 3, nothing] +@test all(is_parameter.((sc,), [:a, :b])) +@test all(.!is_parameter.((sc,), [:x, :y, :z, :t, :q])) +@test parameter_index.((sc,), [:a, :b, :x]) == [1, 2, nothing] +@test is_independent_variable(sc, :t) +@test all(.!is_independent_variable.((sc,), [:x, :y, :z, :a, :b, :q])) +@test current_state(sc) == [:x, :y, :z] +@test all(.!is_observed.((sc,), [:x, :y, :z, :a, :b, :t, :q])) +@test is_time_dependent(sc) +@test constant_structure(sc) + +sc = SymbolCache([:x, :y], [:a, :b]) +@test !is_time_dependent(sc) +# make sure the constructor works +@test_nowarn SymbolCache([:x, :y]) From c6a1036b87341fea940e0527e0920b81128f6c65 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 3 Nov 2023 18:02:24 +0530 Subject: [PATCH 19/35] refactor: remove has_static_variable --- src/SymbolicIndexingInterface.jl | 2 +- src/interface.jl | 11 ++--------- test/example_test.jl | 12 +++++------- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index e93a158..e364d46 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -5,7 +5,7 @@ using Requires export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type include("trait.jl") -export is_variable, has_static_variable, variable_index, is_parameter, parameter_index, +export is_variable, variable_index, is_parameter, parameter_index, is_independent_variable, current_state, is_observed, observed, is_time_dependent, constant_structure include("interface.jl") diff --git a/src/interface.jl b/src/interface.jl index 7609b65..3030c1e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -5,18 +5,11 @@ Check whether the given `sym` is a variable in `sys`. """ function is_variable end -""" - has_static_variable(sys) - -Check whether the variables in `sys` are stable across time. -""" -function has_static_variable end - """ variable_index(sys, sym, [i]) Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. If -[`has_static_variable`](@ref) is `false`, this accepts the current time index as an +[`constant_structure`](@ref) is `false`, this accepts the current time index as an additional parameter `i`. """ function variable_index end @@ -66,7 +59,7 @@ the current time `t` as its third parameter. If `constant_structure(sys) == fals accept a third parameter which can either be a vector of symbols indicating the order of states or a time index which identifies the order of states. -See also: [`is_time_dependent`](@ref), [`has_static_variable`](@ref) +See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref) """ function observed end diff --git a/test/example_test.jl b/test/example_test.jl index 6dfc03f..f5cb633 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -7,13 +7,13 @@ end SymbolicIndexingInterface.is_variable(sys::SystemMockup, sym) = sym in sys.vars function SymbolicIndexingInterface.variable_index(sys::SystemMockup, sym, t = nothing) - if !has_static_variable(sys) && t === nothing + if !constant_structure(sys) && t === nothing error("time index must be present") end findfirst(isequal(sym), current_state(sys, t)) end function SymbolicIndexingInterface.current_state(sys::SystemMockup, i) - return has_static_variable(sys) ? sys.vars : circshift(sys.vars, i) + return constant_structure(sys) ? sys.vars : circshift(sys.vars, i) end SymbolicIndexingInterface.is_parameter(sys::SystemMockup, sym) = sym in sys.params function SymbolicIndexingInterface.parameter_index(sys::SystemMockup, sym) @@ -26,7 +26,7 @@ function SymbolicIndexingInterface.is_observed(sys::SystemMockup, sym) is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym) end function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = nothing) - if !has_static_variable(sys) && states === nothing + if !constant_structure(sys) && states === nothing error("States required") end states = states isa Vector ? states : current_state(sys, states) @@ -44,8 +44,7 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not end end SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t) -SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = true -SymbolicIndexingInterface.has_static_variable(sys::SystemMockup) = sys.static +SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @@ -69,7 +68,6 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test observed(sys, :t)(1:3, 4:6, 1.5) == 1.5 @test is_time_dependent(sys) @test constant_structure(sys) -@test has_static_variable(sys) sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @@ -83,7 +81,7 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @test constant_structure(sys) sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t) -@test !has_static_variable(sys) +@test !constant_structure(sys) for variable in [:x, :y, :z, :a, :b, :c, :t] @test_throws ErrorException variable_index(sys, variable) @test_throws ErrorException observed(sys, variable) From 916c7f17e935be54a5a893a6f87d8ea00ef212a2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 3 Nov 2023 18:02:33 +0530 Subject: [PATCH 20/35] docs: update api.md --- docs/src/api.md | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 79aea93..ec93493 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,13 +1,29 @@ # Interface Functions ```@docs -isvariable -variableindex -isparameter -parameterindex -isindependent_variable -isobserved +is_variable +variable_index +is_parameter +parameter_index +is_independent_variable +current_state +is_observed observed -istimedependent +is_time_dependent constant_structure ``` + +# Traits + +```@docs +ScalarSymbolic +ArraySymbolic +NotSymbolic +symbolic_type +``` + +# Types + +```@docs +SymbolCache +``` From 9204fc51b701398ae41ad2f596c5993bea505d18 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Nov 2023 12:25:23 +0530 Subject: [PATCH 21/35] feat: add Base.copy method for SymbolCache --- src/symbol_cache.jl | 8 ++++++++ test/symbol_cache_test.jl | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index a2c1d4b..164067c 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -27,3 +27,11 @@ current_state(sc::SymbolCache) = sc.variables is_observed(sc::SymbolCache, sym) = false is_time_dependent(sc::SymbolCache) = !isempty(sc.independent_variables) constant_structure(::SymbolCache) = true + +function Base.copy(sc::SymbolCache) + return SymbolCache( + sc.variables === nothing ? nothing : copy(sc.variables), + sc.parameters === nothing ? nothing : copy(sc.parameters), + sc.independent_variables === nothing ? nothing : copy(sc.independent_variables), + ) +end \ No newline at end of file diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index 5da5dbc..d5b3dfd 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -17,3 +17,8 @@ sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) + +sc2 = copy(sc) +@test sc.variables == sc2.variables +@test sc.parameters == sc2.parameters +@test sc.independent_variables == sc2.independent_variables \ No newline at end of file From 8f786f39db1533d5af520e4798c953f93b85da5d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Nov 2023 14:10:04 +0530 Subject: [PATCH 22/35] feat: add functions to get vector of symbols from system --- docs/src/api.md | 4 +++- src/SymbolicIndexingInterface.jl | 6 +++--- src/interface.jl | 26 ++++++++++++++++++++++---- src/symbol_cache.jl | 12 ++++++------ test/example_test.jl | 26 +++++++++++++++++++++++--- test/symbol_cache_test.jl | 7 +++++-- 6 files changed, 62 insertions(+), 19 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index ec93493..f08f7a3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -3,10 +3,12 @@ ```@docs is_variable variable_index +variable_symbols is_parameter parameter_index +parameter_symbols is_independent_variable -current_state +independent_variable_symbols is_observed observed is_time_dependent diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index e364d46..54bd832 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -5,9 +5,9 @@ using Requires export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type include("trait.jl") -export is_variable, variable_index, is_parameter, parameter_index, - is_independent_variable, current_state, is_observed, observed, is_time_dependent, - constant_structure +export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, + parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, + observed, is_time_dependent, constant_structure include("interface.jl") export SymbolCache diff --git a/src/interface.jl b/src/interface.jl index 3030c1e..af56adf 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -14,6 +14,15 @@ additional parameter `i`. """ function variable_index end +""" + variable_symbols(sys, [i]) + +Return a vector of the symbolic variables being solved for in the system `sys`. If +`constant_structure(sys) == false` this accepts an additional parameter indicating +the current time index. The returned vector should not be mutated. +""" +function variable_symbols end + """ is_parameter(sys, sym) @@ -28,19 +37,28 @@ Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise. """ function parameter_index end +""" + parameter_symbols(sys) + +Return a vector of the symbolic parameters of the given system `sys`. The returned +vector should not be mutated. +""" +function parameter_symbols end + """ is_independent_variable(sys, sym) -Check whether the given `sym` is an independent variable in `sys`. +Check whether the given `sym` is an independent variable in `sys`. The returned vector +should not be mutated. """ function is_independent_variable end """ - current_state(sys, i) + independent_variable_symbols(sys) -Return a vector of symbols representing the state variables at timestep `i`. +Return a vector of the symbolic independent variables of the given system `sys`. """ -function current_state end +function independent_variable_symbols end """ is_observed(sys, sym) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 164067c..06165d2 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -20,18 +20,18 @@ end is_variable(sc::SymbolCache, sym) = sym in sc.variables variable_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.variables) +variable_symbols(sc::SymbolCache, i = nothing) = sc.variables is_parameter(sc::SymbolCache, sym) = sym in sc.parameters parameter_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.parameters) +parameter_symbols(sc::SymbolCache) = sc.parameters is_independent_variable(sc::SymbolCache, sym) = sym in sc.independent_variables -current_state(sc::SymbolCache) = sc.variables +independent_variable_symbols(sc::SymbolCache) = sc.independent_variables is_observed(sc::SymbolCache, sym) = false is_time_dependent(sc::SymbolCache) = !isempty(sc.independent_variables) constant_structure(::SymbolCache) = true function Base.copy(sc::SymbolCache) - return SymbolCache( - sc.variables === nothing ? nothing : copy(sc.variables), + return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), sc.parameters === nothing ? nothing : copy(sc.parameters), - sc.independent_variables === nothing ? nothing : copy(sc.independent_variables), - ) -end \ No newline at end of file + sc.independent_variables === nothing ? nothing : copy(sc.independent_variables)) +end diff --git a/test/example_test.jl b/test/example_test.jl index f5cb633..260e8aa 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -10,18 +10,26 @@ function SymbolicIndexingInterface.variable_index(sys::SystemMockup, sym, t = no if !constant_structure(sys) && t === nothing error("time index must be present") end - findfirst(isequal(sym), current_state(sys, t)) + findfirst(isequal(sym), variable_symbols(sys, t)) end -function SymbolicIndexingInterface.current_state(sys::SystemMockup, i) +function SymbolicIndexingInterface.variable_symbols(sys::SystemMockup, i = nothing) return constant_structure(sys) ? sys.vars : circshift(sys.vars, i) end SymbolicIndexingInterface.is_parameter(sys::SystemMockup, sym) = sym in sys.params function SymbolicIndexingInterface.parameter_index(sys::SystemMockup, sym) findfirst(isequal(sym), sys.params) end +SymbolicIndexingInterface.parameter_symbols(sys::SystemMockup) = sys.params function SymbolicIndexingInterface.is_independent_variable(sys::SystemMockup, sym) sys.indepvar !== nothing && isequal(sym, sys.indepvar) end +function SymbolicIndexingInterface.independent_variable_symbols(sys::SystemMockup) + if sys.indepvar === nothing + return [] + else + return [sys.indepvar] + end +end function SymbolicIndexingInterface.is_observed(sys::SystemMockup, sym) is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym) end @@ -29,7 +37,7 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not if !constant_structure(sys) && states === nothing error("States required") end - states = states isa Vector ? states : current_state(sys, states) + states = states isa Vector ? states : variable_symbols(sys, states) if is_variable(sys, sym) return is_time_dependent(sys) ? (u, p, t) -> u[findfirst(isequal(sym), states)] : @@ -68,6 +76,9 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test observed(sys, :t)(1:3, 4:6, 1.5) == 1.5 @test is_time_dependent(sys) @test constant_structure(sys) +@test variable_symbols(sys) == [:x, :y, :z] +@test parameter_symbols(sys) == [:a, :b, :c] +@test independent_variable_symbols(sys) == [:t] sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @@ -79,6 +90,9 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @test observed(sys, :b)(1:3, 4:6) == 5 @test observed(sys, :c)(1:3, 4:6) == 6 @test constant_structure(sys) +@test variable_symbols(sys) == [:x, :y, :z] +@test parameter_symbols(sys) == [:a, :b, :c] +@test independent_variable_symbols(sys) == [] sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t) @test !constant_structure(sys) @@ -95,3 +109,9 @@ end @test observed(sys, :b, 2)(1:3, 4:6, 1.5) == 5 @test observed(sys, :c, 2)(1:3, 4:6, 1.5) == 6 @test observed(sys, :t, 2)(1:3, 4:6, 1.5) == 1.5 +@test_throws Exception variable_symbols(sys) +@test variable_symbols(sys, 1) == [:z, :x, :y] +@test variable_symbols(sys, 2) == [:y, :z, :x] +@test variable_symbols(sys, 3) == [:x, :y, :z] +@test parameter_symbols(sys) == [:a, :b, :c] +@test independent_variable_symbols(sys) == [:t] diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index d5b3dfd..be1b6d9 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -8,17 +8,20 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) @test parameter_index.((sc,), [:a, :b, :x]) == [1, 2, nothing] @test is_independent_variable(sc, :t) @test all(.!is_independent_variable.((sc,), [:x, :y, :z, :a, :b, :q])) -@test current_state(sc) == [:x, :y, :z] @test all(.!is_observed.((sc,), [:x, :y, :z, :a, :b, :t, :q])) @test is_time_dependent(sc) @test constant_structure(sc) +@test variable_symbols(sc) == [:x, :y, :z] +@test parameter_symbols(sc) == [:a, :b] +@test independent_variable_symbols(sc) == [:t] sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) +@test independent_variable_symbols(sc) == [] sc2 = copy(sc) @test sc.variables == sc2.variables @test sc.parameters == sc2.parameters -@test sc.independent_variables == sc2.independent_variables \ No newline at end of file +@test sc.independent_variables == sc2.independent_variables From 5fb49288f0dad42415c5bdd390a002edbc758f31 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Nov 2023 15:52:01 +0530 Subject: [PATCH 23/35] feat: add getname to API --- docs/src/api.md | 1 + src/SymbolicIndexingInterface.jl | 2 +- src/trait.jl | 8 ++++++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index f08f7a3..ff7855e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -22,6 +22,7 @@ ScalarSymbolic ArraySymbolic NotSymbolic symbolic_type +getname ``` # Types diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 54bd832..8a4cadb 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -2,7 +2,7 @@ module SymbolicIndexingInterface using Requires -export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type +export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, getname include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, diff --git a/src/trait.jl b/src/trait.jl index 8b2a95a..ea4df5b 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -39,3 +39,11 @@ See also: [`ScalarSymbolic`](@ref), [`ArraySymbolic`](@ref), [`NotSymbolic`](@re symbolic_type(x) = symbolic_type(typeof(x)) symbolic_type(::Type) = NotSymbolic() symbolic_type(::Type{Symbol}) = ScalarSymbolic() + + +""" + getname(x)::Symbol + +Get the name of a symbolic variable as a `Symbol` +""" +function getname end \ No newline at end of file From 064d9f041c329fd7b306f78da826bbb4cdc5131f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 20 Nov 2023 13:45:27 +0530 Subject: [PATCH 24/35] refactor: remove SymbolicsExt --- Project.toml | 3 --- ext/SymbolicIndexingInterfaceSymbolicsExt.jl | 14 -------------- src/SymbolicIndexingInterface.jl | 3 +-- 3 files changed, 1 insertion(+), 19 deletions(-) delete mode 100644 ext/SymbolicIndexingInterfaceSymbolicsExt.jl diff --git a/Project.toml b/Project.toml index a139bb4..873c31c 100644 --- a/Project.toml +++ b/Project.toml @@ -8,16 +8,13 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" [weakdeps] SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" -Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] SymbolicIndexingInterfaceSymbolicUtilsExt = "SymbolicUtils" -SymbolicIndexingInterfaceSymbolicsExt = "Symbolics" [compat] Requires = "1.0" SymbolicUtils = "1.4" -Symbolics = "5.10" julia = "1" [extras] diff --git a/ext/SymbolicIndexingInterfaceSymbolicsExt.jl b/ext/SymbolicIndexingInterfaceSymbolicsExt.jl deleted file mode 100644 index 662d53f..0000000 --- a/ext/SymbolicIndexingInterfaceSymbolicsExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -module SymbolicIndexingInterfaceSymbolicsExt - -using SymbolicIndexingInterface - -@static if isdefined(Base, :get_extension) - using Symbolics -else - using ..Symbolics -end - -SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() -SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() - -end diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 8a4cadb..57d0ba6 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -2,7 +2,7 @@ module SymbolicIndexingInterface using Requires -export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, getname +export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type , getname include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, @@ -15,7 +15,6 @@ include("symbol_cache.jl") @static if !isdefined(Base, :get_extension) function __init__() - @require Symbolics="0c5d862f-8b57-4792-8d23-62f2024744c7" include("../ext/SymbolicIndexingInterfaceSymbolicsExt.jl") @require SymbolicUtils="d1185830-fcd6-423d-90d6-eec64667417b" include("../ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl") end end From 5b6e191e5b1219191a21cbf63ed5dfaea6cf5876 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 Nov 2023 10:53:04 +0530 Subject: [PATCH 25/35] feat: add `hasname` function --- docs/src/api.md | 1 + src/SymbolicIndexingInterface.jl | 2 +- src/trait.jl | 9 +++++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index ff7855e..a222a68 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -22,6 +22,7 @@ ScalarSymbolic ArraySymbolic NotSymbolic symbolic_type +hasname getname ``` diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 57d0ba6..cff0e06 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -2,7 +2,7 @@ module SymbolicIndexingInterface using Requires -export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type , getname +export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, diff --git a/src/trait.jl b/src/trait.jl index ea4df5b..bb53ac3 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -40,6 +40,15 @@ symbolic_type(x) = symbolic_type(typeof(x)) symbolic_type(::Type) = NotSymbolic() symbolic_type(::Type{Symbol}) = ScalarSymbolic() +""" + hasname(x) + +Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`. +""" +function hasname end + +hasname(::Symbol) = true +hasname(::Any) = false """ getname(x)::Symbol From 723720504e45332b6bfd5ea082c69fbeb3c62d55 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 24 Nov 2023 15:31:15 +0530 Subject: [PATCH 26/35] refactor: fix SymbolCache SII methods --- src/symbol_cache.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 06165d2..70f213d 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -18,13 +18,13 @@ function SymbolCache(vars::Vector{V}, params = [], indepvars = []) where {V} return SymbolCache{V, eltype(params), eltype(indepvars)}(vars, params, indepvars) end -is_variable(sc::SymbolCache, sym) = sym in sc.variables +is_variable(sc::SymbolCache, sym) = any(isequal(sym), sc.variables) variable_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.variables) variable_symbols(sc::SymbolCache, i = nothing) = sc.variables -is_parameter(sc::SymbolCache, sym) = sym in sc.parameters +is_parameter(sc::SymbolCache, sym) = any(isequal(sym), sc.parameters) parameter_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.parameters) parameter_symbols(sc::SymbolCache) = sc.parameters -is_independent_variable(sc::SymbolCache, sym) = sym in sc.independent_variables +is_independent_variable(sc::SymbolCache, sym) = any(isequal(sym), sc.independent_variables) independent_variable_symbols(sc::SymbolCache) = sc.independent_variables is_observed(sc::SymbolCache, sym) = false is_time_dependent(sc::SymbolCache) = !isempty(sc.independent_variables) From 6a823db4a21ffa94a0b63df57a6eeff1c0cdb3e8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 24 Nov 2023 15:31:32 +0530 Subject: [PATCH 27/35] docs: improve ArraySymbolic docstring --- src/trait.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/trait.jl b/src/trait.jl index bb53ac3..af74ee1 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -12,7 +12,10 @@ struct ScalarSymbolic <: SymbolicTypeTrait end """ struct ArraySymbolic <: SymbolicTypeTrait end -Trait indicating type is a symbolic array or an array of scalar symbolic variables. +Trait indicating type is a symbolic array. Calling `collect` on a symbolic array must +return an `AbstractArray` containing `ScalarSymbolic` variables for each element in the +array, in the same shape as the represented array. For example, if `a` is a symbolic array +representing a 2x2 matrix, `collect(a)` must return a 2x2 array of scalar symbolic variables. See also: [`ScalarSymbolic`](@ref), [`NotSymbolic`](@ref), [`symbolic_type`](@ref) """ From b8ff5bacb9323a236aee1bebd6f7fb352302d92b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 24 Nov 2023 15:32:02 +0530 Subject: [PATCH 28/35] feat: add symbolic_container, ParameterIndexingProxy, parameter_values and tests --- src/SymbolicIndexingInterface.jl | 5 ++- src/interface.jl | 39 ++++++++++++++------- src/parameter_indexing_proxy.jl | 49 +++++++++++++++++++++++++++ test/fallback_test.jl | 23 +++++++++++++ test/parameter_indexing_proxy_test.jl | 32 +++++++++++++++++ test/runtests.jl | 6 ++++ 6 files changed, 141 insertions(+), 13 deletions(-) create mode 100644 src/parameter_indexing_proxy.jl create mode 100644 test/fallback_test.jl create mode 100644 test/parameter_indexing_proxy_test.jl diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index cff0e06..6759f38 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -7,12 +7,15 @@ include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, - observed, is_time_dependent, constant_structure + observed, is_time_dependent, constant_structure, symbolic_container include("interface.jl") export SymbolCache include("symbol_cache.jl") +export ParameterIndexingProxy, parameter_values +include("parameter_indexing_proxy.jl") + @static if !isdefined(Base, :get_extension) function __init__() @require SymbolicUtils="d1185830-fcd6-423d-90d6-eec64667417b" include("../ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl") diff --git a/src/interface.jl b/src/interface.jl index af56adf..bcce0b2 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,9 +1,21 @@ +""" + symbolic_container(p) + +Using `p`, return an object that implements the symbolic indexing interface. In case `p` +itself implements the interface, `p` can be returned as-is. All symbolic indexing interface +methods fall back to calling the same method on `symbolic_container(p)`, so this may be +used for trivial implementations of the interface that forward all calls to another object. + +This is also used by [`ParameterIndexingProxy`](@ref) +""" +function symbolic_container end + """ is_variable(sys, sym) Check whether the given `sym` is a variable in `sys`. """ -function is_variable end +is_variable(sys, sym) = is_variable(symbolic_container(sys), sym) """ variable_index(sys, sym, [i]) @@ -12,7 +24,8 @@ Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. I [`constant_structure`](@ref) is `false`, this accepts the current time index as an additional parameter `i`. """ -function variable_index end +variable_index(sys, sym) = variable_index(symbolic_container(sys), sym) +variable_index(sys, sym, i) = variable_index(symbolic_container(sys), sym, i) """ variable_symbols(sys, [i]) @@ -21,21 +34,22 @@ Return a vector of the symbolic variables being solved for in the system `sys`. `constant_structure(sys) == false` this accepts an additional parameter indicating the current time index. The returned vector should not be mutated. """ -function variable_symbols end +variable_symbols(sys) = variable_symbols(symbolic_container(sys)) +variable_symbols(sys, i) = variable_symbols(symbolic_container(sys), i) """ is_parameter(sys, sym) Check whether the given `sym` is a parameter in `sys`. """ -function is_parameter end +is_parameter(sys, sym) = is_parameter(symbolic_container(sys), sym) """ parameter_index(sys, sym) Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise. """ -function parameter_index end +parameter_index(sys, sym) = parameter_index(symbolic_container(sys), sym) """ parameter_symbols(sys) @@ -43,7 +57,7 @@ function parameter_index end Return a vector of the symbolic parameters of the given system `sys`. The returned vector should not be mutated. """ -function parameter_symbols end +parameter_symbols(sys) = parameter_symbols(symbolic_container(sys)) """ is_independent_variable(sys, sym) @@ -51,21 +65,21 @@ function parameter_symbols end Check whether the given `sym` is an independent variable in `sys`. The returned vector should not be mutated. """ -function is_independent_variable end +is_independent_variable(sys, sym) = is_independent_variable(symbolic_container(sys), sym) """ independent_variable_symbols(sys) Return a vector of the symbolic independent variables of the given system `sys`. """ -function independent_variable_symbols end +independent_variable_symbols(sys) = independent_variable_symbols(symbolic_container(sys)) """ is_observed(sys, sym) Check whether the given `sym` is an observed value in `sys`. """ -function is_observed end +is_observed(sys, sym) = is_observed(symbolic_container(sys), sym) """ observed(sys, sym, [states]) @@ -79,14 +93,15 @@ of states or a time index which identifies the order of states. See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref) """ -function observed end +observed(sys, sym) = observed(symbolic_container(sys), sym) +observed(sys, sym, states) = observed(symbolic_container(sys), sym, states) """ is_time_dependent(sys) Check if `sys` has time as (one of) its independent variables. """ -function is_time_dependent end +is_time_dependent(sys) = is_time_dependent(symbolic_container(sys)) """ constant_structure(sys) @@ -94,4 +109,4 @@ function is_time_dependent end Check if `sys` has a constant structure. Constant structure systems do not change the number of variables or parameters over time. """ -function constant_structure end +constant_structure(sys) = constant_structure(symbolic_container(sys)) diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl new file mode 100644 index 0000000..50e1efd --- /dev/null +++ b/src/parameter_indexing_proxy.jl @@ -0,0 +1,49 @@ +""" + parameter_values(p) + +Return an indexable collection containing the value of each parameter in `p`. +""" +function parameter_values end + +""" + struct ParameterIndexingProxy end + ParameterIndexingProxy(p) + +A wrapper struct that allows symbolic indexing of parameters. The wrapped object `p` +must implement [`symbolic_container`](@ref) and [`parameter_values`](@ref). Indexing +of parameters using numeric indices is also permitted. +""" +struct ParameterIndexingProxy{T} + wrapped::T +end + +function Base.getindex(p::ParameterIndexingProxy, args...) + symtype = symbolic_type(first(args)) + elsymtype = symbolic_type(eltype(first(args))) + + if symtype != NotSymbolic() + getindex(p, symtype, args...) + else + getindex(p, elsymtype, args...) + end +end + +function Base.getindex(p::ParameterIndexingProxy, ::NotSymbolic, args) + parameter_values(p.wrapped)[args...] +end + +function Base.getindex(p::ParameterIndexingProxy, ::ScalarSymbolic, sym) + sc = symbolic_container(p.wrapped) + if is_parameter(sc, sym) + return parameter_values(p.wrapped)[parameter_index(sc, sym)] + end + error("Parameter indexing error: $sym is not a parameter") +end + +function Base.getindex(p::ParameterIndexingProxy, ::ScalarSymbolic, sym::Union{AbstractArray,Tuple}) + return getindex.((p,), sym) +end + +function Base.getindex(p::ParameterIndexingProxy, ::ArraySymbolic, sym) + return getindex(p, collect(sym)) +end \ No newline at end of file diff --git a/test/fallback_test.jl b/test/fallback_test.jl new file mode 100644 index 0000000..172d8ee --- /dev/null +++ b/test/fallback_test.jl @@ -0,0 +1,23 @@ +using SymbolicIndexingInterface + +struct Wrapper{W} + wrapped::W +end + +SymbolicIndexingInterface.symbolic_container(w::Wrapper) = w.wrapped + +sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +sys = Wrapper(sc) + +all_syms = [:x, :y, :z, :a, :b, :t] +@test is_variable.((sys,), all_syms) == is_variable.((sc,), all_syms) +@test variable_index.((sys,), all_syms) == variable_index.((sc,), all_syms) +@test is_parameter.((sys,), all_syms) == is_parameter.((sc,), all_syms) +@test parameter_index.((sys,), all_syms) == parameter_index.((sc,), all_syms) +@test is_independent_variable.((sys,), all_syms) == is_independent_variable.((sc,), all_syms) +@test is_observed.((sys,), all_syms) == is_observed.((sc,), all_syms) +@test is_time_dependent(sys) == is_time_dependent(sc) +@test constant_structure(sys) == constant_structure(sc) +@test variable_symbols(sys) == variable_symbols(sc) +@test parameter_symbols(sys) == parameter_symbols(sc) +@test independent_variable_symbols(sys) == independent_variable_symbols(sc) diff --git a/test/parameter_indexing_proxy_test.jl b/test/parameter_indexing_proxy_test.jl new file mode 100644 index 0000000..642d14d --- /dev/null +++ b/test/parameter_indexing_proxy_test.jl @@ -0,0 +1,32 @@ +using SymbolicIndexingInterface +using Symbolics + +struct FakeProblem{S,P} + sys::S + p::P +end + +SymbolicIndexingInterface.symbolic_container(fp::FakeProblem) = fp.sys +SymbolicIndexingInterface.parameter_values(fp::FakeProblem) = fp.p + +@variables a[1:2] b +sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) + +for p in ([1.0, 2.0, 3.0], (1.0, 2.0, 3.0), [1.0 2.0 3.0]) + fp = FakeProblem(sys, p) + pip = ParameterIndexingProxy(fp) + # numeric indexing still works + for i in eachindex(p) + @test pip[i] == p[i] + end + # index with individual symbols + for (i, sym) in enumerate(parameter_symbols(fp)) + @test pip[sym] == p[i] + end + # index with array of symbols + @test pip[parameter_symbols(fp)] == vec(collect(p)) + # index with tuple of symbols + @test pip[Tuple(parameter_symbols(fp))] == Tuple(p) + # index with symbolic array + @test pip[a] == collect(p)[1:2] +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index f6aa6e3..250865a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,3 +10,9 @@ end @testset "SymbolCache test" begin @time include("symbol_cache_test.jl") end +@testset "Fallback test" begin + @time include("fallback_test.jl") +end +@testset "Parameter indexing proxy test" begin + @time include("parameter_indexing_proxy_test.jl") +end \ No newline at end of file From 4e75bff3bcfb6ffacbe68da4d20caea6b43186ab Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 24 Nov 2023 15:36:36 +0530 Subject: [PATCH 29/35] refactor: add PARAMETER_INDEXING_PROXY_PROPERTY_NAME --- src/parameter_indexing_proxy.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index 50e1efd..1b70735 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -1,3 +1,5 @@ +const PARAMETER_INDEXING_PROXY_PROPERTY_NAME = :ps + """ parameter_values(p) From a110ee4c3495c2322e1fb66941fd230908bb597d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 28 Nov 2023 11:05:15 +0530 Subject: [PATCH 30/35] refactor: remove ParameterIndexingProxy, add get_p and set_p functions --- src/SymbolicIndexingInterface.jl | 4 +- src/parameter_indexing.jl | 89 +++++++++++++++++++++++++++ src/parameter_indexing_proxy.jl | 51 --------------- test/parameter_indexing_proxy_test.jl | 32 ---------- test/parameter_indexing_test.jl | 23 +++++++ test/runtests.jl | 4 +- 6 files changed, 116 insertions(+), 87 deletions(-) create mode 100644 src/parameter_indexing.jl delete mode 100644 src/parameter_indexing_proxy.jl delete mode 100644 test/parameter_indexing_proxy_test.jl create mode 100644 test/parameter_indexing_test.jl diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 6759f38..a69897a 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -13,8 +13,8 @@ include("interface.jl") export SymbolCache include("symbol_cache.jl") -export ParameterIndexingProxy, parameter_values -include("parameter_indexing_proxy.jl") +export parameter_values, getp, setp +include("parameter_indexing.jl") @static if !isdefined(Base, :get_extension) function __init__() diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl new file mode 100644 index 0000000..b106c26 --- /dev/null +++ b/src/parameter_indexing.jl @@ -0,0 +1,89 @@ +""" + parameter_values(p) + +Return an indexable collection containing the value of each parameter in `p`. +""" +function parameter_values end + +""" + getp(sys, p) + +Return a function that takes an integrator or solution of `sys`, and returns the value of +the parameter `p`. Requires that the integrator or solution implement +[`parameter_values`](@ref). +""" +function getp(sys, p) + symtype = symbolic_type(p) + elsymtype = symbolic_type(eltype(p)) + if symtype != NotSymbolic() + return _getp(sys, symtype, p) + else + return _getp(sys, elsymtype, p) + end +end + +function _getp(sys, ::NotSymbolic, p) + return function getter(sol) + return parameter_values(sol)[p] + end +end + +function _getp(sys, ::ScalarSymbolic, p) + idx = parameter_index(sys, p) + return function getter(sol) + return parameter_values(sol)[idx] + end +end + +function _getp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray}) + idxs = parameter_index.((sys,), p) + return function getter(sol) + return getindex.((parameter_values(sol),), idxs) + end +end + +function _getp(sys, ::ArraySymbolic, p) + return getp(sys, collect(p)) +end + +""" + setp(sys, p) + +Return a function that takes an integrator of `sys` and a value, and sets the +the parameter `p` to that value. Requires that the integrator implement +[`parameter_values`](@ref) and the returned collection be a mutable reference +to the parameter vector in the integrator. +""" +function setp(sys, p) + symtype = symbolic_type(p) + elsymtype = symbolic_type(eltype(p)) + if symtype != NotSymbolic() + return _setp(sys, symtype, p) + else + return _setp(sys, elsymtype, p) + end +end + +function _setp(sys, ::NotSymbolic, p) + return function setter!(sol, val) + parameter_values(sol)[p] = val + end +end + +function _setp(sys, ::ScalarSymbolic, p) + idx = parameter_index(sys, p) + return function setter!(sol, val) + parameter_values(sol)[idx] = val + end +end + +function _setp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray}) + idxs = parameter_index.((sys,), p) + return function setter!(sol, val) + setindex!.((parameter_values(sol),), val, idxs) + end +end + +function _setp(sys, ::ArraySymbolic, p) + return setp(sys, collect(p)) +end diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl deleted file mode 100644 index 1b70735..0000000 --- a/src/parameter_indexing_proxy.jl +++ /dev/null @@ -1,51 +0,0 @@ -const PARAMETER_INDEXING_PROXY_PROPERTY_NAME = :ps - -""" - parameter_values(p) - -Return an indexable collection containing the value of each parameter in `p`. -""" -function parameter_values end - -""" - struct ParameterIndexingProxy end - ParameterIndexingProxy(p) - -A wrapper struct that allows symbolic indexing of parameters. The wrapped object `p` -must implement [`symbolic_container`](@ref) and [`parameter_values`](@ref). Indexing -of parameters using numeric indices is also permitted. -""" -struct ParameterIndexingProxy{T} - wrapped::T -end - -function Base.getindex(p::ParameterIndexingProxy, args...) - symtype = symbolic_type(first(args)) - elsymtype = symbolic_type(eltype(first(args))) - - if symtype != NotSymbolic() - getindex(p, symtype, args...) - else - getindex(p, elsymtype, args...) - end -end - -function Base.getindex(p::ParameterIndexingProxy, ::NotSymbolic, args) - parameter_values(p.wrapped)[args...] -end - -function Base.getindex(p::ParameterIndexingProxy, ::ScalarSymbolic, sym) - sc = symbolic_container(p.wrapped) - if is_parameter(sc, sym) - return parameter_values(p.wrapped)[parameter_index(sc, sym)] - end - error("Parameter indexing error: $sym is not a parameter") -end - -function Base.getindex(p::ParameterIndexingProxy, ::ScalarSymbolic, sym::Union{AbstractArray,Tuple}) - return getindex.((p,), sym) -end - -function Base.getindex(p::ParameterIndexingProxy, ::ArraySymbolic, sym) - return getindex(p, collect(sym)) -end \ No newline at end of file diff --git a/test/parameter_indexing_proxy_test.jl b/test/parameter_indexing_proxy_test.jl deleted file mode 100644 index 642d14d..0000000 --- a/test/parameter_indexing_proxy_test.jl +++ /dev/null @@ -1,32 +0,0 @@ -using SymbolicIndexingInterface -using Symbolics - -struct FakeProblem{S,P} - sys::S - p::P -end - -SymbolicIndexingInterface.symbolic_container(fp::FakeProblem) = fp.sys -SymbolicIndexingInterface.parameter_values(fp::FakeProblem) = fp.p - -@variables a[1:2] b -sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) - -for p in ([1.0, 2.0, 3.0], (1.0, 2.0, 3.0), [1.0 2.0 3.0]) - fp = FakeProblem(sys, p) - pip = ParameterIndexingProxy(fp) - # numeric indexing still works - for i in eachindex(p) - @test pip[i] == p[i] - end - # index with individual symbols - for (i, sym) in enumerate(parameter_symbols(fp)) - @test pip[sym] == p[i] - end - # index with array of symbols - @test pip[parameter_symbols(fp)] == vec(collect(p)) - # index with tuple of symbols - @test pip[Tuple(parameter_symbols(fp))] == Tuple(p) - # index with symbolic array - @test pip[a] == collect(p)[1:2] -end \ No newline at end of file diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl new file mode 100644 index 0000000..a058314 --- /dev/null +++ b/test/parameter_indexing_test.jl @@ -0,0 +1,23 @@ +using SymbolicIndexingInterface +using Symbolics + +struct FakeIntegrator{P} + p::P +end + +SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys +SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p + +@variables a[1:2] b +sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) +p = [1.0, 2.0, 3.0] +fi = FakeIntegrator(copy(p)) +for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))] + get = getp(sys, sym) + set! = setp(sys, sym) + true_value = i isa Tuple ? getindex.((p,), i) : p[i] + @test get(fi) == true_value + set!(fi, 0.5 .* i) + @test get(fi) == 0.5 .* i + set!(fi, true_value) +end diff --git a/test/runtests.jl b/test/runtests.jl index 250865a..b76e653 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,6 @@ end @testset "Fallback test" begin @time include("fallback_test.jl") end -@testset "Parameter indexing proxy test" begin - @time include("parameter_indexing_proxy_test.jl") +@testset "Parameter indexing test" begin + @time include("parameter_indexing_test.jl") end \ No newline at end of file From b4e781c3cfe149fabda9eae5631e880a182618d0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 29 Nov 2023 17:33:42 +0530 Subject: [PATCH 31/35] refactor: rework SymbolCache --- src/symbol_cache.jl | 60 ++++++++++++++++++++++++++++----------- test/symbol_cache_test.jl | 16 +++++++++++ 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 70f213d..b622d81 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -7,31 +7,59 @@ of having a vector of variables, parameters and independent variables. This struct does not implement `observed`, and `is_observed` returns `false` for all input symbols. It is considered to be time dependent if it contains at least one independent variable. + +The independent variable may be specified as a single symbolic variable instead of an +array containing a single variable if the system has only one independent variable. """ -struct SymbolCache{V, P, I} - variables::Vector{V} - parameters::Vector{P} - independent_variables::Vector{I} +struct SymbolCache{V<:Union{Nothing,AbstractVector}, P<:Union{Nothing,AbstractVector}, I} + variables::V + parameters::P + independent_variables::I end -function SymbolCache(vars::Vector{V}, params = [], indepvars = []) where {V} - return SymbolCache{V, eltype(params), eltype(indepvars)}(vars, params, indepvars) +function SymbolCache(vars = nothing, params = nothing, indepvars = nothing) + return SymbolCache{typeof(vars),typeof(params),typeof(indepvars)}(vars, params, indepvars) end -is_variable(sc::SymbolCache, sym) = any(isequal(sym), sc.variables) -variable_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.variables) -variable_symbols(sc::SymbolCache, i = nothing) = sc.variables -is_parameter(sc::SymbolCache, sym) = any(isequal(sym), sc.parameters) -parameter_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.parameters) -parameter_symbols(sc::SymbolCache) = sc.parameters -is_independent_variable(sc::SymbolCache, sym) = any(isequal(sym), sc.independent_variables) -independent_variable_symbols(sc::SymbolCache) = sc.independent_variables +is_variable(sc::SymbolCache, sym) = sc.variables !== nothing && any(isequal(sym), sc.variables) +variable_index(sc::SymbolCache, sym) = sc.variables === nothing ? nothing : findfirst(isequal(sym), sc.variables) +variable_symbols(sc::SymbolCache, i = nothing) = something(sc.variables, []) +is_parameter(sc::SymbolCache, sym) = sc.parameters !== nothing && any(isequal(sym), sc.parameters) +parameter_index(sc::SymbolCache, sym) = sc.parameters === nothing ? nothing : findfirst(isequal(sym), sc.parameters) +parameter_symbols(sc::SymbolCache) = something(sc.parameters, []) +function is_independent_variable(sc::SymbolCache, sym) + sc.independent_variables === nothing && return false + if symbolic_type(sc.independent_variables) == NotSymbolic() + return any(isequal(sym), sc.independent_variables) + elseif symbolic_type(sc.independent_variables) == ScalarSymbolic() + return sym == sc.independent_variables + else + return any(isequal(sym), collect(sc.independent_variables)) + end +end +function independent_variable_symbols(sc::SymbolCache) + sc.independent_variables === nothing && return [] + if symbolic_type(sc.independent_variables) == NotSymbolic() + return sc.independent_variables + elseif symbolic_type(sc.independent_variables) == ScalarSymbolic() + return [sc.independent_variables] + else + return collect(sc.independent_variables) + end +end is_observed(sc::SymbolCache, sym) = false -is_time_dependent(sc::SymbolCache) = !isempty(sc.independent_variables) +function is_time_dependent(sc::SymbolCache) + sc.independent_variables === nothing && return false + if symbolic_type(sc.independent_variables) == NotSymbolic() + return !isempty(sc.independent_variables) + else + return true + end +end constant_structure(::SymbolCache) = true function Base.copy(sc::SymbolCache) return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), sc.parameters === nothing ? nothing : copy(sc.parameters), - sc.independent_variables === nothing ? nothing : copy(sc.independent_variables)) + sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) : sc.independent_variables) end diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index be1b6d9..8754b4f 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -19,7 +19,23 @@ sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) + +sc = SymbolCache() +@test all(.!is_variable.((sc,), [:x, :y, :a, :b, :t])) +@test all(variable_index.((sc,), [:x, :y, :a, :b, :t]) .== nothing) +@test variable_symbols(sc) == [] +@test all(.!is_parameter.((sc,), [:x, :y, :a, :b, :t])) +@test all(parameter_index.((sc,), [:x, :y, :a, :b, :t]) .== nothing) +@test parameter_symbols(sc) == [] +@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t])) @test independent_variable_symbols(sc) == [] +@test !is_time_dependent(sc) + +sc = SymbolCache(nothing, nothing, :t) +@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b])) +@test is_independent_variable(sc, :t) +@test independent_variable_symbols(sc) == [:t] +@test is_time_dependent(sc) sc2 = copy(sc) @test sc.variables == sc2.variables From 2f13f30e6ef6729053a771c679bf4d04d92fc377 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Dec 2023 16:24:01 +0530 Subject: [PATCH 32/35] docs: add a couple functions to api.md --- docs/src/api.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index a222a68..065a104 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,6 +1,7 @@ # Interface Functions ```@docs +symbolic_container is_variable variable_index variable_symbols @@ -13,6 +14,9 @@ is_observed observed is_time_dependent constant_structure +parameter_values +getp +setp ``` # Traits From 1e5450b160c619d9adb259cb4a7307cb36adfa32 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Dec 2023 17:06:15 +0530 Subject: [PATCH 33/35] test: move Symbolics tests to Symbolics downstream --- .github/workflows/Downstream.yml | 1 + test/parameter_indexing_test.jl | 8 +++----- test/trait_test.jl | 11 +---------- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 1953828..531d600 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -18,6 +18,7 @@ jobs: os: [ubuntu-latest] package: - {user: SciML, repo: RecursiveArrayTools.jl, group: All} + - {user: JuliaSymbolics, repo: Symbolics.jl, group: SymbolicIndexingInterface} steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index a058314..66b1893 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -1,5 +1,4 @@ using SymbolicIndexingInterface -using Symbolics struct FakeIntegrator{P} p::P @@ -8,11 +7,10 @@ end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p -@variables a[1:2] b -sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) -p = [1.0, 2.0, 3.0] +sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +p = [1.0, 2.0] fi = FakeIntegrator(copy(p)) -for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))] +for (i, sym) in [(1, :a), (2, :b), ([1, 2], [:a, :b]), ((1, 2), (:a, :b))] get = getp(sys, sym) set! = setp(sys, sym) true_value = i isa Tuple ? getindex.((p,), i) : p[i] diff --git a/test/trait_test.jl b/test/trait_test.jl index 9527997..de7eb85 100644 --- a/test/trait_test.jl +++ b/test/trait_test.jl @@ -1,15 +1,6 @@ -using SymbolicUtils -using Symbolics using SymbolicIndexingInterface using Test @test all(symbolic_type.([Int, Float64, String, Bool, UInt, Complex{Float64}]) .== (NotSymbolic(),)) -@test all(symbolic_type.([Symbol, SymbolicUtils.BasicSymbolic, Symbolics.Num]) .== - (ScalarSymbolic(),)) -@test symbolic_type(Symbolics.Arr) == ArraySymbolic() -@variables x -@test symbolic_type(x) == ScalarSymbolic() -@variables y[1:3] -@test symbolic_type(y) == ArraySymbolic() -@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),)) +@test symbolic_type(Symbol) == ScalarSymbolic() From b243589b097c19d880b470ec33629bc99ebcb472 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Dec 2023 17:08:45 +0530 Subject: [PATCH 34/35] refactor: format --- src/parameter_indexing.jl | 4 ++-- src/symbol_cache.jl | 29 ++++++++++++++++++++++------- src/trait.jl | 2 +- test/fallback_test.jl | 3 ++- test/runtests.jl | 2 +- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index b106c26..62608d8 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -35,7 +35,7 @@ function _getp(sys, ::ScalarSymbolic, p) end end -function _getp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray}) +function _getp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray}) idxs = parameter_index.((sys,), p) return function getter(sol) return getindex.((parameter_values(sol),), idxs) @@ -77,7 +77,7 @@ function _setp(sys, ::ScalarSymbolic, p) end end -function _setp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray}) +function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray}) idxs = parameter_index.((sys,), p) return function setter!(sol, val) setindex!.((parameter_values(sol),), val, idxs) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index b622d81..c2d89fa 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -11,21 +11,35 @@ at least one independent variable. The independent variable may be specified as a single symbolic variable instead of an array containing a single variable if the system has only one independent variable. """ -struct SymbolCache{V<:Union{Nothing,AbstractVector}, P<:Union{Nothing,AbstractVector}, I} +struct SymbolCache{ + V <: Union{Nothing, AbstractVector}, + P <: Union{Nothing, AbstractVector}, + I, +} variables::V parameters::P independent_variables::I end function SymbolCache(vars = nothing, params = nothing, indepvars = nothing) - return SymbolCache{typeof(vars),typeof(params),typeof(indepvars)}(vars, params, indepvars) + return SymbolCache{typeof(vars), typeof(params), typeof(indepvars)}(vars, + params, + indepvars) end -is_variable(sc::SymbolCache, sym) = sc.variables !== nothing && any(isequal(sym), sc.variables) -variable_index(sc::SymbolCache, sym) = sc.variables === nothing ? nothing : findfirst(isequal(sym), sc.variables) +function is_variable(sc::SymbolCache, sym) + sc.variables !== nothing && any(isequal(sym), sc.variables) +end +function variable_index(sc::SymbolCache, sym) + sc.variables === nothing ? nothing : findfirst(isequal(sym), sc.variables) +end variable_symbols(sc::SymbolCache, i = nothing) = something(sc.variables, []) -is_parameter(sc::SymbolCache, sym) = sc.parameters !== nothing && any(isequal(sym), sc.parameters) -parameter_index(sc::SymbolCache, sym) = sc.parameters === nothing ? nothing : findfirst(isequal(sym), sc.parameters) +function is_parameter(sc::SymbolCache, sym) + sc.parameters !== nothing && any(isequal(sym), sc.parameters) +end +function parameter_index(sc::SymbolCache, sym) + sc.parameters === nothing ? nothing : findfirst(isequal(sym), sc.parameters) +end parameter_symbols(sc::SymbolCache) = something(sc.parameters, []) function is_independent_variable(sc::SymbolCache, sym) sc.independent_variables === nothing && return false @@ -61,5 +75,6 @@ constant_structure(::SymbolCache) = true function Base.copy(sc::SymbolCache) return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), sc.parameters === nothing ? nothing : copy(sc.parameters), - sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) : sc.independent_variables) + sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) : + sc.independent_variables) end diff --git a/src/trait.jl b/src/trait.jl index af74ee1..6660e91 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -58,4 +58,4 @@ hasname(::Any) = false Get the name of a symbolic variable as a `Symbol` """ -function getname end \ No newline at end of file +function getname end diff --git a/test/fallback_test.jl b/test/fallback_test.jl index 172d8ee..4c354bf 100644 --- a/test/fallback_test.jl +++ b/test/fallback_test.jl @@ -14,7 +14,8 @@ all_syms = [:x, :y, :z, :a, :b, :t] @test variable_index.((sys,), all_syms) == variable_index.((sc,), all_syms) @test is_parameter.((sys,), all_syms) == is_parameter.((sc,), all_syms) @test parameter_index.((sys,), all_syms) == parameter_index.((sc,), all_syms) -@test is_independent_variable.((sys,), all_syms) == is_independent_variable.((sc,), all_syms) +@test is_independent_variable.((sys,), all_syms) == + is_independent_variable.((sc,), all_syms) @test is_observed.((sys,), all_syms) == is_observed.((sc,), all_syms) @test is_time_dependent(sys) == is_time_dependent(sc) @test constant_structure(sys) == constant_structure(sc) diff --git a/test/runtests.jl b/test/runtests.jl index b76e653..20538d7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,4 +15,4 @@ end end @testset "Parameter indexing test" begin @time include("parameter_indexing_test.jl") -end \ No newline at end of file +end From 7a5c62fde3c89079e7bdf8a75772ebb2c75a782f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Dec 2023 17:26:31 +0530 Subject: [PATCH 35/35] refactor: remove SymbolicUtils extension --- Project.toml | 15 +-------------- ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl | 14 -------------- src/SymbolicIndexingInterface.jl | 8 -------- 3 files changed, 1 insertion(+), 36 deletions(-) delete mode 100644 ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl diff --git a/Project.toml b/Project.toml index 873c31c..82ee6fd 100644 --- a/Project.toml +++ b/Project.toml @@ -3,24 +3,11 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" authors = ["Aayush Sabharwal and contributors"] version = "0.3.0" -[deps] -Requires = "ae029012-a4dd-5104-9daa-d747884805df" - -[weakdeps] -SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" - -[extensions] -SymbolicIndexingInterfaceSymbolicUtilsExt = "SymbolicUtils" - [compat] -Requires = "1.0" -SymbolicUtils = "1.4" julia = "1" [extras] -SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" -Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["SymbolicUtils", "Symbolics","Test"] +test = ["Test"] diff --git a/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl b/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl deleted file mode 100644 index 0b202b9..0000000 --- a/ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -module SymbolicIndexingInterfaceSymbolicUtilsExt - -using SymbolicIndexingInterface -@static if isdefined(Base, :get_extension) - using SymbolicUtils -else - using ..SymbolicUtils -end - -function SymbolicIndexingInterface.symbolic_type(::Type{<:SymbolicUtils.BasicSymbolic}) - ScalarSymbolic() -end - -end diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index a69897a..7a33e0d 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -1,7 +1,5 @@ module SymbolicIndexingInterface -using Requires - export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname include("trait.jl") @@ -16,10 +14,4 @@ include("symbol_cache.jl") export parameter_values, getp, setp include("parameter_indexing.jl") -@static if !isdefined(Base, :get_extension) - function __init__() - @require SymbolicUtils="d1185830-fcd6-423d-90d6-eec64667417b" include("../ext/SymbolicIndexingInterfaceSymbolicUtilsExt.jl") - end -end - end