diff --git a/docs/src/api.md b/docs/src/api.md index 4811a57d..dfaadfbf 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -75,6 +75,7 @@ set_parameter! finalize_parameters_hook! getp setp +setp_oop ParameterIndexingProxy ``` diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 5e302e84..a368dc08 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -32,7 +32,7 @@ include("value_provider_interface.jl") export ParameterTimeseriesCollection include("parameter_timeseries_collection.jl") -export getp, setp +export getp, setp, setp_oop include("parameter_indexing.jl") export getu, setu diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 8f4d5ea5..1c61f8aa 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -647,7 +647,7 @@ end """ setp(indp, sym) -Return a function that takes an index provider and a value, and sets the parameter `sym` +Return a function that takes a value provider and a value, and sets the parameter `sym` to that value. Note that `sym` can be an index, a symbolic variable, or an array/tuple of the aforementioned. @@ -709,3 +709,69 @@ function _setp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p) end return setp(sys, collect(p); run_hook = false) end + +""" + setp_oop(indp, sym) + +Return a function which takes a value provider `valp` and a value `val`, and returns +`parameter_values(valp)` with the parameters at `sym` set to `val`. This allows changing +the types of values stored, and leverages [`remake_buffer`](@ref). Note that `sym` can be +an index, a symbolic variable, or an array/tuple of the aforementioned. + +Requires that the value provider implement `parameter_values` and `remake_buffer`. +""" +function setp_oop(indp, sym) + symtype = symbolic_type(sym) + elsymtype = symbolic_type(eltype(sym)) + return _setp_oop(indp, symtype, elsymtype, sym) +end + +struct OOPSetter{I, D} + indp::I + idxs::D +end + +function (os::OOPSetter)(valp, val) + return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,)) +end + +function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray}) + if os.idxs isa Union{Tuple, AbstractArray} + return remake_buffer(os.indp, parameter_values(valp), os.idxs, val) + else + return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,)) + end +end + +function _root_indp(indp) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) && + (sc = symbolic_container(indp)) != indp + return _root_indp(sc) + else + return indp + end +end + +function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym) + return OOPSetter(_root_indp(indp), sym) +end + +function _setp_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) + return OOPSetter(_root_indp(indp), parameter_index(indp, sym)) +end + +for (t1, t2) in [ + (ScalarSymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}) +] + @eval function _setp_oop(indp, ::NotSymbolic, ::$t1, sym::$t2) + return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym)) + end +end + +function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym) + if is_parameter(indp, sym) + return OOPSetter(_root_indp(indp), parameter_index(indp, sym)) + end + error("$sym is not a valid parameter") +end diff --git a/src/remake.jl b/src/remake.jl index ea458f6e..4d915ad6 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -1,24 +1,34 @@ """ - remake_buffer(indp, oldbuffer, vals::Dict) - -Return a copy of the buffer `oldbuffer` with values from `vals`. The keys of `vals` -are symbolic variables whose index in the buffer is determined using `indp`. The types of -values in `vals` may not match the types of values stored at the corresponding indexes in -the buffer, in which case the type of the buffer should be promoted accordingly. In -general, this method should attempt to preserve the types of values stored in `vals` as -much as possible. Types can be promoted for type-stability, to maintain performance. The -returned buffer should be of the same type (ignoring type-parameters) as `oldbuffer`. - -This method is already implemented for -`remake_buffer(indp, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays -as well. It is also implemented for `oldbuffer::Tuple`. + remake_buffer(indp, oldbuffer, idxs, vals) + +Return a copy of the buffer `oldbuffer` with at (optionally symbolic) indexes `idxs` +replaced by corresponding values from `vals`. Both `idxs` and `vals` must be iterables of +the same length. `idxs` may contain symbolic variables whose index in the buffer is +determined using `indp`. The types of values in `vals` may not match the types of values +stored at the corresponding indexes in the buffer, in which case the type of the buffer +should be promoted accordingly. In general, this method should attempt to preserve the +types of values stored in `vals` as much as possible. Types can be promoted for +type-stability, to maintain performance. The returned buffer should be of the same type +(ignoring type-parameters) as `oldbuffer`. + +This method is already implemented for `oldbuffer::AbstractArray` and `oldbuffer::Tuple`, +and supports static arrays as well. + +The deprecated version of this method which takes a `Dict` mapping symbols to values +instead of `idxs` and `vals` will dispatch to the new method. In addition if +no `remake_buffer` method exists with the new signature, it will call +`remake_buffer(sys, oldbuffer, Dict(idxs .=> vals))`. + +Note that the new method signature allows `idxs` to be indexes, instead of requiring +that they be symbolic variables. Thus, any type which implements the new method must +also support indexes in `idxs`. """ -function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict) +function remake_buffer(sys, oldbuffer::AbstractArray, idxs, vals) # similar when used with an `MArray` and nonconcrete eltype returns a # SizedArray. `similar_type` still returns an `MArray` if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray) elT = Union{} - for val in values(vals) + for val in vals if val isa AbstractArray valT = eltype(val) else @@ -29,7 +39,8 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict) newbuffer = similar(oldbuffer, elT) copyto!(newbuffer, oldbuffer) - for (k, v) in vals + for (k, v) in zip(idxs, vals) + is_variable(sys, k) || is_parameter(sys, k) || continue if v isa AbstractArray v = elT.(v) else @@ -38,12 +49,16 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict) setu(sys, k)(newbuffer, v) end else - mutbuffer = remake_buffer(sys, collect(oldbuffer), vals) + mutbuffer = remake_buffer(sys, collect(oldbuffer), idxs, vals) newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer) end return newbuffer end +function remake_buffer(sys, oldbuffer, idxs, vals) + remake_buffer(sys, oldbuffer, Dict(idxs .=> vals)) +end + mutable struct TupleRemakeWrapper t::Tuple end @@ -54,8 +69,19 @@ function set_parameter!(sys::TupleRemakeWrapper, val, idx) sys.t = tp end -function remake_buffer(sys, oldbuffer::Tuple, vals::Dict) +function set_state!(sys::TupleRemakeWrapper, val, idx) + tp = sys.t + @reset tp[idx] = val + sys.t = tp +end + +function remake_buffer(sys, oldbuffer::Tuple, idxs, vals) wrap = TupleRemakeWrapper(oldbuffer) - setu(sys, collect(keys(vals)))(wrap, values(vals)) + setu(sys, idxs)(wrap, vals) return wrap.t end + +@deprecate remake_buffer(sys, oldbuffer, vals::Dict) remake_buffer( + sys, oldbuffer, keys(vals), values(vals)) +@deprecate remake_buffer(sys, oldbuffer::Tuple, vals::Dict) remake_buffer( + sys, oldbuffer, collect(keys(vals)), collect(values(vals))) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 954b5e24..8657d46d 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -5,4 +5,4 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] -SymbolicUtils = "<1.6" +SymbolicUtils = "3.2" diff --git a/test/downstream/remake_arrayvars.jl b/test/downstream/remake_arrayvars.jl index 2e8e6356..b80f1861 100644 --- a/test/downstream/remake_arrayvars.jl +++ b/test/downstream/remake_arrayvars.jl @@ -7,5 +7,5 @@ using SymbolicIndexingInterface sys = complete(sys) u0 = [1.0, 2.0, 3.0] -newu0 = remake_buffer(sys, u0, Dict(x => [5.0, 6.0], y => 7.0)) +newu0 = remake_buffer(sys, u0, [x, y], ([5.0, 6.0], 7.0)) @test newu0 == [5.0, 6.0, 7.0] diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 43d6f5dc..6c864c44 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -168,6 +168,18 @@ for sys in [ @test getter(fi) == [] getter = getp(sys, ()) @test getter(fi) == () + + for (sym, val) in [ + (:a, 1.0f1), + (1, 1.0f1), + ([:a, :b], [1.0f1, 2.0f1]), + ((:b, :c), (2.0f1, 3.0f1)) + ] + setter = setp_oop(fi, sym) + newp = setter(fi, val) + getter = getp(sys, sym) + @test getter(newp) == val + end end end diff --git a/test/remake_test.jl b/test/remake_test.jl index 0618e2e9..74c267eb 100644 --- a/test/remake_test.jl +++ b/test/remake_test.jl @@ -3,47 +3,60 @@ using StaticArrays sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) -for (buf, newbuf, newvals) in [ +for (buf, newbuf, idxs, vals) in [ # standard operation - ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)), + ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [:x, :y, :z], [2.0, 3.0, 4.0]), # buffer type "demotion" - ([1.0, 2.0, 3.0], [2, 2, 3], Dict(:x => 2)), + ([1.0, 2.0, 3.0], [2, 2, 3], [:x], [2]), # buffer type promotion - ([1, 2, 3], [2.0, 2.0, 3.0], Dict(:x => 2.0)), + ([1, 2, 3], [2.0, 2.0, 3.0], [:x], [2.0]), # value type promotion - ([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2, :y => 3.0, :z => 4.0)), + ([1, 2, 3], [2.0, 3.0, 4.0], [:x, :y, :z], Real[2, 3.0, 4.0]), # standard operation - ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)), + ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [:a, :b, :c], [2.0, 3.0, 4.0]), # buffer type "demotion" - ([1.0, 2.0, 3.0], [2, 2, 3], Dict(:a => 2)), + ([1.0, 2.0, 3.0], [2, 2, 3], [:a], [2]), # buffer type promotion - ([1, 2, 3], [2.0, 2.0, 3.0], Dict(:a => 2.0)), + ([1, 2, 3], [2.0, 2.0, 3.0], [:a], [2.0]), # value type promotion - ([1, 2, 3], [2, 3.0, 4.0], Dict(:a => 2, :b => 3.0, :c => 4.0)) + ([1, 2, 3], [2, 3.0, 4.0], [:a, :b, :c], Real[2, 3.0, 4.0]), + # skip non-parameters + ([1, 2, 3], [2.0, 3.0, 3.0], [:a, :b, :(a + b)], [2.0, 3.0, 5.0]) ] for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}] buf = arrType(buf) newbuf = arrType(newbuf) - _newbuf = remake_buffer(sys, buf, newvals) + _newbuf = remake_buffer(sys, buf, idxs, vals) @test _newbuf != buf # should not alias @test newbuf == _newbuf # test values @test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type + @test_deprecated remake_buffer(sys, buf, Dict(idxs .=> vals)) end end -# Tuples not allowed for state -for (buf, newbuf, newvals) in [ +for (buf, newbuf, idxs, vals) in [ # standard operation - ((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)), + ((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), [:a, :b, :c], [2.0, 3.0, 4.0]), # buffer type "demotion" - ((1.0, 2.0, 3.0), (2, 3, 4), Dict(:a => 2, :b => 3, :c => 4)), + ((1.0, 2.0, 3.0), (2, 3, 4), [:a, :b, :c], [2, 3, 4]), # buffer type promotion - ((1, 2, 3), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)), + ((1, 2, 3), (2.0, 3.0, 4.0), [:a, :b, :c], [2.0, 3.0, 4.0]), # value type promotion - ((1, 2, 3), (2, 3.0, 4.0), Dict(:a => 2, :b => 3.0, :c => 4.0)) + ((1, 2, 3), (2, 3.0, 4.0), [:a, :b, :c], Real[2, 3.0, 4.0]), + # standard operation + ((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), [:x, :y, :z], [2.0, 3.0, 4.0]), + # buffer type "demotion" + ((1.0, 2.0, 3.0), (2, 3, 4), [:x, :y, :z], [2, 3, 4]), + # buffer type promotion + ((1, 2, 3), (2.0, 3.0, 4.0), [:x, :y, :z], [2.0, 3.0, 4.0]), + # value type promotion + ((1, 2, 3), (2, 3.0, 4.0), [:x, :y, :z], Real[2, 3.0, 4.0]), + # skip non-variables + ([1, 2, 3], [2.0, 3.0, 3.0], [:x, :y, :(x + y)], [2.0, 3.0, 5.0]) ] - _newbuf = remake_buffer(sys, buf, newvals) + _newbuf = remake_buffer(sys, buf, idxs, vals) @test newbuf == _newbuf # test values @test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type + @test_deprecated remake_buffer(sys, buf, Dict(idxs .=> vals)) end diff --git a/test/runtests.jl b/test/runtests.jl index b91706cf..363d9fbe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,4 +55,7 @@ if GROUP == "All" || GROUP == "Downstream" @safetestset "BatchedInterface with array symbolics test" begin @time include("downstream/batchedinterface_arrayvars.jl") end + @safetestset "remake_buffer with array symbolics test" begin + @time include("downstream/remake_arrayvars.jl") + end end