Skip to content

Commit

Permalink
Merge pull request #58 from SciML/as/setp-callback
Browse files Browse the repository at this point in the history
feat: add `finalize_parameters_hook!`
  • Loading branch information
ChrisRackauckas authored Mar 25, 2024
2 parents ba967a9 + ba71411 commit b3404cf
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ observed
```@docs
parameter_values
set_parameter!
finalize_parameters_hook!
getp
setp
ParameterIndexingProxy
Expand Down
9 changes: 9 additions & 0 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val
end
```

### Using `finalize_parameters_hook!`

The function [`finalize_parameters_hook!`](@ref) is called exactly _once_ every time the
function returned by `setp` is called. This allows performing any additional bookkeeping
required when parameter values are updated. [`set_parameter!`](@ref) also allows performing
similar functionality, but is called for every parameter that is updated, instead of just
once. Thus, `finalize_parameters_hook!` is better for expensive computations that can be
performed for a bulk parameter update.

# The `ParameterIndexingProxy`

[`ParameterIndexingProxy`](@ref) is a wrapper around another type which implements the
Expand Down
46 changes: 36 additions & 10 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ function set_parameter!(sys::AbstractArray, val, idx)
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)

"""
finalize_parameters_hook!(prob, p)
This is a callback run one for each call to the function returned by [`setp`](@ref)
which can be used to update internal data structures when parameters are modified.
This is in contrast to [`set_parameter!`](@ref) which is run once for each parameter
that is updated.
"""
finalize_parameters_hook!(prob, p) = nothing

"""
getp(sys, p)
Expand Down Expand Up @@ -231,22 +241,36 @@ case `parameter_values` cannot return such a mutable reference, or additional ac
need to be performed when updating parameters, [`set_parameter!`](@ref) must be
implemented.
"""
function setp(sys, p)
function setp(sys, p; run_hook = true)
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
_setp(sys, symtype, elsymtype, p)
return if run_hook
let _setter! = _setp(sys, symtype, elsymtype, p), p = p
function setter!(prob, args...)
res = _setter!(prob, args...)
finalize_parameters_hook!(prob, p)
res
end
end
else
_setp(sys, symtype, elsymtype, p)
end
end

function _setp(sys, ::NotSymbolic, ::NotSymbolic, p)
return function setter!(sol, val)
set_parameter!(sol, val, p)
return let p = p
function setter!(sol, val)
set_parameter!(sol, val, p)
end
end
end

function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return function setter!(sol, val)
set_parameter!(sol, val, idx)
return let idx = idx
function setter!(sol, val)
set_parameter!(sol, val, idx)
end
end
end

Expand All @@ -256,13 +280,15 @@ for (t1, t2) in [
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
setters = setp.((sys,), p)
return function setter!(sol, val)
map((s!, v) -> s!(sol, v), setters, val)
setters = setp.((sys,), p; run_hook = false)
return let setters = setters
function setter!(sol, val)
map((s!, v) -> s!(sol, v), setters, val)
end
end
end
end

function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p)
return setp(sys, collect(p))
return setp(sys, collect(p); run_hook = false)
end
14 changes: 13 additions & 1 deletion test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@ using Test
struct FakeIntegrator{S, P}
sys::S
p::P
counter::Ref{Int}
end

function Base.getproperty(fi::FakeIntegrator, s::Symbol)
s === :ps ? ParameterIndexingProxy(fi) : getfield(fi, s)
end
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p
function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, p)
fi.counter[] += 1
end

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(p))
fi = FakeIntegrator(sys, copy(p), Ref(0))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
Expand All @@ -39,19 +43,25 @@ for (sym, oldval, newval, check_inference) in [
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval
@test fi.counter[] == 0
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test fi.counter[] == 1

@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2

fi.ps[sym] = newval
@test get(fi) == newval
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
@test fi.counter[] == 4

if check_inference
@inferred get(p)
Expand All @@ -65,6 +75,8 @@ for (sym, oldval, newval, check_inference) in [
@test get(p) == newval
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
fi.counter[] = 0
end

for (sym, val) in [
Expand Down

0 comments on commit b3404cf

Please sign in to comment.