Skip to content

Commit

Permalink
feat: add finalize_parameters_hook!
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 21, 2024
1 parent 584bf3d commit ba71411
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ observed
```@docs
parameter_values
set_parameter!
finalize_parameters_hook!
getp
setp
ParameterIndexingProxy
Expand Down
9 changes: 9 additions & 0 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val
end
```

### Using `finalize_parameters_hook!`

The function [`finalize_parameters_hook!`](@ref) is called exactly _once_ every time the
function returned by `setp` is called. This allows performing any additional bookkeeping
required when parameter values are updated. [`set_parameter!`](@ref) also allows performing
similar functionality, but is called for every parameter that is updated, instead of just
once. Thus, `finalize_parameters_hook!` is better for expensive computations that can be
performed for a bulk parameter update.

# The `ParameterIndexingProxy`

[`ParameterIndexingProxy`](@ref) is a wrapper around another type which implements the
Expand Down
46 changes: 36 additions & 10 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ function set_parameter!(sys::AbstractArray, val, idx)
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)

"""
finalize_parameters_hook!(prob, p)
This is a callback run one for each call to the function returned by [`setp`](@ref)
which can be used to update internal data structures when parameters are modified.
This is in contrast to [`set_parameter!`](@ref) which is run once for each parameter
that is updated.
"""
finalize_parameters_hook!(prob, p) = nothing

Check warning on line 93 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L93

Added line #L93 was not covered by tests

"""
getp(sys, p)
Expand Down Expand Up @@ -231,22 +241,36 @@ case `parameter_values` cannot return such a mutable reference, or additional ac
need to be performed when updating parameters, [`set_parameter!`](@ref) must be
implemented.
"""
function setp(sys, p)
function setp(sys, p; run_hook = true)

Check warning on line 244 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L244

Added line #L244 was not covered by tests
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
_setp(sys, symtype, elsymtype, p)
return if run_hook
let _setter! = _setp(sys, symtype, elsymtype, p), p = p
function setter!(prob, args...)
res = _setter!(prob, args...)
finalize_parameters_hook!(prob, p)
res

Check warning on line 252 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L247-L252

Added lines #L247 - L252 were not covered by tests
end
end
else
_setp(sys, symtype, elsymtype, p)

Check warning on line 256 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L256

Added line #L256 was not covered by tests
end
end

function _setp(sys, ::NotSymbolic, ::NotSymbolic, p)
return function setter!(sol, val)
set_parameter!(sol, val, p)
return let p = p
function setter!(sol, val)
set_parameter!(sol, val, p)

Check warning on line 263 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L261-L263

Added lines #L261 - L263 were not covered by tests
end
end
end

function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return function setter!(sol, val)
set_parameter!(sol, val, idx)
return let idx = idx
function setter!(sol, val)
set_parameter!(sol, val, idx)

Check warning on line 272 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L270-L272

Added lines #L270 - L272 were not covered by tests
end
end
end

Expand All @@ -256,13 +280,15 @@ for (t1, t2) in [
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
setters = setp.((sys,), p)
return function setter!(sol, val)
map((s!, v) -> s!(sol, v), setters, val)
setters = setp.((sys,), p; run_hook = false)
return let setters = setters
function setter!(sol, val)
map((s!, v) -> s!(sol, v), setters, val)

Check warning on line 286 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L283-L286

Added lines #L283 - L286 were not covered by tests
end
end
end
end

function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p)
return setp(sys, collect(p))
return setp(sys, collect(p); run_hook = false)

Check warning on line 293 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L293

Added line #L293 was not covered by tests
end
14 changes: 13 additions & 1 deletion test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@ using Test
struct FakeIntegrator{S, P}
sys::S
p::P
counter::Ref{Int}
end

function Base.getproperty(fi::FakeIntegrator, s::Symbol)
s === :ps ? ParameterIndexingProxy(fi) : getfield(fi, s)
end
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p
function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, p)
fi.counter[] += 1
end

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(p))
fi = FakeIntegrator(sys, copy(p), Ref(0))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
Expand All @@ -39,19 +43,25 @@ for (sym, oldval, newval, check_inference) in [
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval
@test fi.counter[] == 0
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test fi.counter[] == 1

@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2

fi.ps[sym] = newval
@test get(fi) == newval
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
@test fi.counter[] == 4

if check_inference
@inferred get(p)
Expand All @@ -65,6 +75,8 @@ for (sym, oldval, newval, check_inference) in [
@test get(p) == newval
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
fi.counter[] = 0
end

for (sym, val) in [
Expand Down

0 comments on commit ba71411

Please sign in to comment.