From a110ee4c3495c2322e1fb66941fd230908bb597d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 28 Nov 2023 11:05:15 +0530 Subject: [PATCH] refactor: remove ParameterIndexingProxy, add get_p and set_p functions --- src/SymbolicIndexingInterface.jl | 4 +- src/parameter_indexing.jl | 89 +++++++++++++++++++++++++++ src/parameter_indexing_proxy.jl | 51 --------------- test/parameter_indexing_proxy_test.jl | 32 ---------- test/parameter_indexing_test.jl | 23 +++++++ test/runtests.jl | 4 +- 6 files changed, 116 insertions(+), 87 deletions(-) create mode 100644 src/parameter_indexing.jl delete mode 100644 src/parameter_indexing_proxy.jl delete mode 100644 test/parameter_indexing_proxy_test.jl create mode 100644 test/parameter_indexing_test.jl diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 6759f38..a69897a 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -13,8 +13,8 @@ include("interface.jl") export SymbolCache include("symbol_cache.jl") -export ParameterIndexingProxy, parameter_values -include("parameter_indexing_proxy.jl") +export parameter_values, getp, setp +include("parameter_indexing.jl") @static if !isdefined(Base, :get_extension) function __init__() diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl new file mode 100644 index 0000000..b106c26 --- /dev/null +++ b/src/parameter_indexing.jl @@ -0,0 +1,89 @@ +""" + parameter_values(p) + +Return an indexable collection containing the value of each parameter in `p`. +""" +function parameter_values end + +""" + getp(sys, p) + +Return a function that takes an integrator or solution of `sys`, and returns the value of +the parameter `p`. Requires that the integrator or solution implement +[`parameter_values`](@ref). +""" +function getp(sys, p) + symtype = symbolic_type(p) + elsymtype = symbolic_type(eltype(p)) + if symtype != NotSymbolic() + return _getp(sys, symtype, p) + else + return _getp(sys, elsymtype, p) + end +end + +function _getp(sys, ::NotSymbolic, p) + return function getter(sol) + return parameter_values(sol)[p] + end +end + +function _getp(sys, ::ScalarSymbolic, p) + idx = parameter_index(sys, p) + return function getter(sol) + return parameter_values(sol)[idx] + end +end + +function _getp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray}) + idxs = parameter_index.((sys,), p) + return function getter(sol) + return getindex.((parameter_values(sol),), idxs) + end +end + +function _getp(sys, ::ArraySymbolic, p) + return getp(sys, collect(p)) +end + +""" + setp(sys, p) + +Return a function that takes an integrator of `sys` and a value, and sets the +the parameter `p` to that value. Requires that the integrator implement +[`parameter_values`](@ref) and the returned collection be a mutable reference +to the parameter vector in the integrator. +""" +function setp(sys, p) + symtype = symbolic_type(p) + elsymtype = symbolic_type(eltype(p)) + if symtype != NotSymbolic() + return _setp(sys, symtype, p) + else + return _setp(sys, elsymtype, p) + end +end + +function _setp(sys, ::NotSymbolic, p) + return function setter!(sol, val) + parameter_values(sol)[p] = val + end +end + +function _setp(sys, ::ScalarSymbolic, p) + idx = parameter_index(sys, p) + return function setter!(sol, val) + parameter_values(sol)[idx] = val + end +end + +function _setp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray}) + idxs = parameter_index.((sys,), p) + return function setter!(sol, val) + setindex!.((parameter_values(sol),), val, idxs) + end +end + +function _setp(sys, ::ArraySymbolic, p) + return setp(sys, collect(p)) +end diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl deleted file mode 100644 index 1b70735..0000000 --- a/src/parameter_indexing_proxy.jl +++ /dev/null @@ -1,51 +0,0 @@ -const PARAMETER_INDEXING_PROXY_PROPERTY_NAME = :ps - -""" - parameter_values(p) - -Return an indexable collection containing the value of each parameter in `p`. -""" -function parameter_values end - -""" - struct ParameterIndexingProxy end - ParameterIndexingProxy(p) - -A wrapper struct that allows symbolic indexing of parameters. The wrapped object `p` -must implement [`symbolic_container`](@ref) and [`parameter_values`](@ref). Indexing -of parameters using numeric indices is also permitted. -""" -struct ParameterIndexingProxy{T} - wrapped::T -end - -function Base.getindex(p::ParameterIndexingProxy, args...) - symtype = symbolic_type(first(args)) - elsymtype = symbolic_type(eltype(first(args))) - - if symtype != NotSymbolic() - getindex(p, symtype, args...) - else - getindex(p, elsymtype, args...) - end -end - -function Base.getindex(p::ParameterIndexingProxy, ::NotSymbolic, args) - parameter_values(p.wrapped)[args...] -end - -function Base.getindex(p::ParameterIndexingProxy, ::ScalarSymbolic, sym) - sc = symbolic_container(p.wrapped) - if is_parameter(sc, sym) - return parameter_values(p.wrapped)[parameter_index(sc, sym)] - end - error("Parameter indexing error: $sym is not a parameter") -end - -function Base.getindex(p::ParameterIndexingProxy, ::ScalarSymbolic, sym::Union{AbstractArray,Tuple}) - return getindex.((p,), sym) -end - -function Base.getindex(p::ParameterIndexingProxy, ::ArraySymbolic, sym) - return getindex(p, collect(sym)) -end \ No newline at end of file diff --git a/test/parameter_indexing_proxy_test.jl b/test/parameter_indexing_proxy_test.jl deleted file mode 100644 index 642d14d..0000000 --- a/test/parameter_indexing_proxy_test.jl +++ /dev/null @@ -1,32 +0,0 @@ -using SymbolicIndexingInterface -using Symbolics - -struct FakeProblem{S,P} - sys::S - p::P -end - -SymbolicIndexingInterface.symbolic_container(fp::FakeProblem) = fp.sys -SymbolicIndexingInterface.parameter_values(fp::FakeProblem) = fp.p - -@variables a[1:2] b -sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) - -for p in ([1.0, 2.0, 3.0], (1.0, 2.0, 3.0), [1.0 2.0 3.0]) - fp = FakeProblem(sys, p) - pip = ParameterIndexingProxy(fp) - # numeric indexing still works - for i in eachindex(p) - @test pip[i] == p[i] - end - # index with individual symbols - for (i, sym) in enumerate(parameter_symbols(fp)) - @test pip[sym] == p[i] - end - # index with array of symbols - @test pip[parameter_symbols(fp)] == vec(collect(p)) - # index with tuple of symbols - @test pip[Tuple(parameter_symbols(fp))] == Tuple(p) - # index with symbolic array - @test pip[a] == collect(p)[1:2] -end \ No newline at end of file diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl new file mode 100644 index 0000000..a058314 --- /dev/null +++ b/test/parameter_indexing_test.jl @@ -0,0 +1,23 @@ +using SymbolicIndexingInterface +using Symbolics + +struct FakeIntegrator{P} + p::P +end + +SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys +SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p + +@variables a[1:2] b +sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) +p = [1.0, 2.0, 3.0] +fi = FakeIntegrator(copy(p)) +for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))] + get = getp(sys, sym) + set! = setp(sys, sym) + true_value = i isa Tuple ? getindex.((p,), i) : p[i] + @test get(fi) == true_value + set!(fi, 0.5 .* i) + @test get(fi) == 0.5 .* i + set!(fi, true_value) +end diff --git a/test/runtests.jl b/test/runtests.jl index 250865a..b76e653 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,6 @@ end @testset "Fallback test" begin @time include("fallback_test.jl") end -@testset "Parameter indexing proxy test" begin - @time include("parameter_indexing_proxy_test.jl") +@testset "Parameter indexing test" begin + @time include("parameter_indexing_test.jl") end \ No newline at end of file