Skip to content

Commit

Permalink
refactor: remove ParameterIndexingProxy, add get_p and set_p functions
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 28, 2023
1 parent 4e75bff commit a110ee4
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 87 deletions.
4 changes: 2 additions & 2 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ include("interface.jl")
export SymbolCache
include("symbol_cache.jl")

export ParameterIndexingProxy, parameter_values
include("parameter_indexing_proxy.jl")
export parameter_values, getp, setp
include("parameter_indexing.jl")

@static if !isdefined(Base, :get_extension)
function __init__()
Expand Down
89 changes: 89 additions & 0 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
parameter_values(p)
Return an indexable collection containing the value of each parameter in `p`.
"""
function parameter_values end

"""
getp(sys, p)
Return a function that takes an integrator or solution of `sys`, and returns the value of
the parameter `p`. Requires that the integrator or solution implement
[`parameter_values`](@ref).
"""
function getp(sys, p)
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
if symtype != NotSymbolic()
return _getp(sys, symtype, p)
else
return _getp(sys, elsymtype, p)
end
end

function _getp(sys, ::NotSymbolic, p)
return function getter(sol)
return parameter_values(sol)[p]
end
end

function _getp(sys, ::ScalarSymbolic, p)
idx = parameter_index(sys, p)
return function getter(sol)
return parameter_values(sol)[idx]
end
end

function _getp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray})
idxs = parameter_index.((sys,), p)
return function getter(sol)
return getindex.((parameter_values(sol),), idxs)
end
end

function _getp(sys, ::ArraySymbolic, p)
return getp(sys, collect(p))
end

"""
setp(sys, p)
Return a function that takes an integrator of `sys` and a value, and sets the
the parameter `p` to that value. Requires that the integrator implement
[`parameter_values`](@ref) and the returned collection be a mutable reference
to the parameter vector in the integrator.
"""
function setp(sys, p)
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
if symtype != NotSymbolic()
return _setp(sys, symtype, p)
else
return _setp(sys, elsymtype, p)
end
end

function _setp(sys, ::NotSymbolic, p)
return function setter!(sol, val)
parameter_values(sol)[p] = val
end
end

function _setp(sys, ::ScalarSymbolic, p)
idx = parameter_index(sys, p)
return function setter!(sol, val)
parameter_values(sol)[idx] = val
end
end

function _setp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray})
idxs = parameter_index.((sys,), p)
return function setter!(sol, val)
setindex!.((parameter_values(sol),), val, idxs)
end
end

function _setp(sys, ::ArraySymbolic, p)
return setp(sys, collect(p))
end
51 changes: 0 additions & 51 deletions src/parameter_indexing_proxy.jl

This file was deleted.

32 changes: 0 additions & 32 deletions test/parameter_indexing_proxy_test.jl

This file was deleted.

23 changes: 23 additions & 0 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using SymbolicIndexingInterface
using Symbolics

struct FakeIntegrator{P}
p::P
end

SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p

@variables a[1:2] b
sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(copy(p))
for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))]
get = getp(sys, sym)
set! = setp(sys, sym)
true_value = i isa Tuple ? getindex.((p,), i) : p[i]
@test get(fi) == true_value
set!(fi, 0.5 .* i)
@test get(fi) == 0.5 .* i
set!(fi, true_value)
end
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ end
@testset "Fallback test" begin
@time include("fallback_test.jl")
end
@testset "Parameter indexing proxy test" begin
@time include("parameter_indexing_proxy_test.jl")
@testset "Parameter indexing test" begin
@time include("parameter_indexing_test.jl")
end

0 comments on commit a110ee4

Please sign in to comment.