From ba71411968068f94081282fd35695abcf99cc40a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Mar 2024 12:52:19 +0530 Subject: [PATCH] feat: add `finalize_parameters_hook!` --- docs/src/api.md | 1 + docs/src/complete_sii.md | 9 +++++++ src/parameter_indexing.jl | 46 ++++++++++++++++++++++++++------- test/parameter_indexing_test.jl | 14 +++++++++- 4 files changed, 59 insertions(+), 11 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index b297681..f0b7ae1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -35,6 +35,7 @@ observed ```@docs parameter_values set_parameter! +finalize_parameters_hook! getp setp ParameterIndexingProxy diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index a796a9c..3ad42c4 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -267,6 +267,15 @@ function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val end ``` +### Using `finalize_parameters_hook!` + +The function [`finalize_parameters_hook!`](@ref) is called exactly _once_ every time the +function returned by `setp` is called. This allows performing any additional bookkeeping +required when parameter values are updated. [`set_parameter!`](@ref) also allows performing +similar functionality, but is called for every parameter that is updated, instead of just +once. Thus, `finalize_parameters_hook!` is better for expensive computations that can be +performed for a bulk parameter update. + # The `ParameterIndexingProxy` [`ParameterIndexingProxy`](@ref) is a wrapper around another type which implements the diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 09de194..6cbab3f 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -82,6 +82,16 @@ function set_parameter!(sys::AbstractArray, val, idx) end set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx) +""" + finalize_parameters_hook!(prob, p) + +This is a callback run one for each call to the function returned by [`setp`](@ref) +which can be used to update internal data structures when parameters are modified. +This is in contrast to [`set_parameter!`](@ref) which is run once for each parameter +that is updated. +""" +finalize_parameters_hook!(prob, p) = nothing + """ getp(sys, p) @@ -231,22 +241,36 @@ case `parameter_values` cannot return such a mutable reference, or additional ac need to be performed when updating parameters, [`set_parameter!`](@ref) must be implemented. """ -function setp(sys, p) +function setp(sys, p; run_hook = true) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) - _setp(sys, symtype, elsymtype, p) + return if run_hook + let _setter! = _setp(sys, symtype, elsymtype, p), p = p + function setter!(prob, args...) + res = _setter!(prob, args...) + finalize_parameters_hook!(prob, p) + res + end + end + else + _setp(sys, symtype, elsymtype, p) + end end function _setp(sys, ::NotSymbolic, ::NotSymbolic, p) - return function setter!(sol, val) - set_parameter!(sol, val, p) + return let p = p + function setter!(sol, val) + set_parameter!(sol, val, p) + end end end function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) - return function setter!(sol, val) - set_parameter!(sol, val, idx) + return let idx = idx + function setter!(sol, val) + set_parameter!(sol, val, idx) + end end end @@ -256,13 +280,15 @@ for (t1, t2) in [ (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] @eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2) - setters = setp.((sys,), p) - return function setter!(sol, val) - map((s!, v) -> s!(sol, v), setters, val) + setters = setp.((sys,), p; run_hook = false) + return let setters = setters + function setter!(sol, val) + map((s!, v) -> s!(sol, v), setters, val) + end end end end function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p) - return setp(sys, collect(p)) + return setp(sys, collect(p); run_hook = false) end diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index a32003a..7684659 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -4,6 +4,7 @@ using Test struct FakeIntegrator{S, P} sys::S p::P + counter::Ref{Int} end function Base.getproperty(fi::FakeIntegrator, s::Symbol) @@ -11,10 +12,13 @@ function Base.getproperty(fi::FakeIntegrator, s::Symbol) end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p +function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, p) + fi.counter[] += 1 +end sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) p = [1.0, 2.0, 3.0] -fi = FakeIntegrator(sys, copy(p)) +fi = FakeIntegrator(sys, copy(p), Ref(0)) new_p = [4.0, 5.0, 6.0] @test parameter_timeseries(fi) == [0] for (sym, oldval, newval, check_inference) in [ @@ -39,19 +43,25 @@ for (sym, oldval, newval, check_inference) in [ end @test get(fi) == fi.ps[sym] @test get(fi) == oldval + @test fi.counter[] == 0 if check_inference @inferred set!(fi, newval) else set!(fi, newval) end + @test fi.counter[] == 1 + @test get(fi) == newval set!(fi, oldval) @test get(fi) == oldval + @test fi.counter[] == 2 fi.ps[sym] = newval @test get(fi) == newval + @test fi.counter[] == 3 fi.ps[sym] = oldval @test get(fi) == oldval + @test fi.counter[] == 4 if check_inference @inferred get(p) @@ -65,6 +75,8 @@ for (sym, oldval, newval, check_inference) in [ @test get(p) == newval set!(p, oldval) @test get(p) == oldval + @test fi.counter[] == 4 + fi.counter[] = 0 end for (sym, val) in [