-
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: remove ParameterIndexingProxy, add get_p and set_p functions
- Loading branch information
1 parent
4e75bff
commit a110ee4
Showing
6 changed files
with
116 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
""" | ||
parameter_values(p) | ||
Return an indexable collection containing the value of each parameter in `p`. | ||
""" | ||
function parameter_values end | ||
|
||
""" | ||
getp(sys, p) | ||
Return a function that takes an integrator or solution of `sys`, and returns the value of | ||
the parameter `p`. Requires that the integrator or solution implement | ||
[`parameter_values`](@ref). | ||
""" | ||
function getp(sys, p) | ||
symtype = symbolic_type(p) | ||
elsymtype = symbolic_type(eltype(p)) | ||
if symtype != NotSymbolic() | ||
return _getp(sys, symtype, p) | ||
else | ||
return _getp(sys, elsymtype, p) | ||
end | ||
end | ||
|
||
function _getp(sys, ::NotSymbolic, p) | ||
return function getter(sol) | ||
return parameter_values(sol)[p] | ||
end | ||
end | ||
|
||
function _getp(sys, ::ScalarSymbolic, p) | ||
idx = parameter_index(sys, p) | ||
return function getter(sol) | ||
return parameter_values(sol)[idx] | ||
end | ||
end | ||
|
||
function _getp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray}) | ||
idxs = parameter_index.((sys,), p) | ||
return function getter(sol) | ||
return getindex.((parameter_values(sol),), idxs) | ||
end | ||
end | ||
|
||
function _getp(sys, ::ArraySymbolic, p) | ||
return getp(sys, collect(p)) | ||
end | ||
|
||
""" | ||
setp(sys, p) | ||
Return a function that takes an integrator of `sys` and a value, and sets the | ||
the parameter `p` to that value. Requires that the integrator implement | ||
[`parameter_values`](@ref) and the returned collection be a mutable reference | ||
to the parameter vector in the integrator. | ||
""" | ||
function setp(sys, p) | ||
symtype = symbolic_type(p) | ||
elsymtype = symbolic_type(eltype(p)) | ||
if symtype != NotSymbolic() | ||
return _setp(sys, symtype, p) | ||
else | ||
return _setp(sys, elsymtype, p) | ||
end | ||
end | ||
|
||
function _setp(sys, ::NotSymbolic, p) | ||
return function setter!(sol, val) | ||
parameter_values(sol)[p] = val | ||
end | ||
end | ||
|
||
function _setp(sys, ::ScalarSymbolic, p) | ||
idx = parameter_index(sys, p) | ||
return function setter!(sol, val) | ||
parameter_values(sol)[idx] = val | ||
end | ||
end | ||
|
||
function _setp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray}) | ||
idxs = parameter_index.((sys,), p) | ||
return function setter!(sol, val) | ||
setindex!.((parameter_values(sol),), val, idxs) | ||
end | ||
end | ||
|
||
function _setp(sys, ::ArraySymbolic, p) | ||
return setp(sys, collect(p)) | ||
end |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
using SymbolicIndexingInterface | ||
using Symbolics | ||
|
||
struct FakeIntegrator{P} | ||
p::P | ||
end | ||
|
||
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys | ||
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p | ||
|
||
@variables a[1:2] b | ||
sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) | ||
p = [1.0, 2.0, 3.0] | ||
fi = FakeIntegrator(copy(p)) | ||
for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))] | ||
get = getp(sys, sym) | ||
set! = setp(sys, sym) | ||
true_value = i isa Tuple ? getindex.((p,), i) : p[i] | ||
@test get(fi) == true_value | ||
set!(fi, 0.5 .* i) | ||
@test get(fi) == 0.5 .* i | ||
set!(fi, true_value) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters