Skip to content

Commit

Permalink
Merge pull request #92 from SciML/as/oop-setp
Browse files Browse the repository at this point in the history
feat: add out-of-place `setp`, refactor `remake_buffer`
  • Loading branch information
ChrisRackauckas authored Sep 4, 2024
2 parents 82f1464 + 5a3ac25 commit 8179344
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 40 deletions.
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ set_parameter!
finalize_parameters_hook!
getp
setp
setp_oop
ParameterIndexingProxy
```

Expand Down
2 changes: 1 addition & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ include("value_provider_interface.jl")
export ParameterTimeseriesCollection
include("parameter_timeseries_collection.jl")

export getp, setp
export getp, setp, setp_oop
include("parameter_indexing.jl")

export getu, setu
Expand Down
68 changes: 67 additions & 1 deletion src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ end
"""
setp(indp, sym)
Return a function that takes an index provider and a value, and sets the parameter `sym`
Return a function that takes a value provider and a value, and sets the parameter `sym`
to that value. Note that `sym` can be an index, a symbolic variable, or an array/tuple of
the aforementioned.
Expand Down Expand Up @@ -709,3 +709,69 @@ function _setp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p)
end
return setp(sys, collect(p); run_hook = false)
end

"""
setp_oop(indp, sym)
Return a function which takes a value provider `valp` and a value `val`, and returns
`parameter_values(valp)` with the parameters at `sym` set to `val`. This allows changing
the types of values stored, and leverages [`remake_buffer`](@ref). Note that `sym` can be
an index, a symbolic variable, or an array/tuple of the aforementioned.
Requires that the value provider implement `parameter_values` and `remake_buffer`.
"""
function setp_oop(indp, sym)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))
return _setp_oop(indp, symtype, elsymtype, sym)
end

struct OOPSetter{I, D}
indp::I
idxs::D
end

function (os::OOPSetter)(valp, val)
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
end

function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
if os.idxs isa Union{Tuple, AbstractArray}
return remake_buffer(os.indp, parameter_values(valp), os.idxs, val)
else
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
end
end

function _root_indp(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) != indp
return _root_indp(sc)
else
return indp
end
end

function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym)
return OOPSetter(_root_indp(indp), sym)
end

function _setp_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
end

for (t1, t2) in [
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _setp_oop(indp, ::NotSymbolic, ::$t1, sym::$t2)
return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym))
end
end

function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
if is_parameter(indp, sym)
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
end
error("$sym is not a valid parameter")
end
64 changes: 45 additions & 19 deletions src/remake.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
"""
remake_buffer(indp, oldbuffer, vals::Dict)
Return a copy of the buffer `oldbuffer` with values from `vals`. The keys of `vals`
are symbolic variables whose index in the buffer is determined using `indp`. The types of
values in `vals` may not match the types of values stored at the corresponding indexes in
the buffer, in which case the type of the buffer should be promoted accordingly. In
general, this method should attempt to preserve the types of values stored in `vals` as
much as possible. Types can be promoted for type-stability, to maintain performance. The
returned buffer should be of the same type (ignoring type-parameters) as `oldbuffer`.
This method is already implemented for
`remake_buffer(indp, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays
as well. It is also implemented for `oldbuffer::Tuple`.
remake_buffer(indp, oldbuffer, idxs, vals)
Return a copy of the buffer `oldbuffer` with at (optionally symbolic) indexes `idxs`
replaced by corresponding values from `vals`. Both `idxs` and `vals` must be iterables of
the same length. `idxs` may contain symbolic variables whose index in the buffer is
determined using `indp`. The types of values in `vals` may not match the types of values
stored at the corresponding indexes in the buffer, in which case the type of the buffer
should be promoted accordingly. In general, this method should attempt to preserve the
types of values stored in `vals` as much as possible. Types can be promoted for
type-stability, to maintain performance. The returned buffer should be of the same type
(ignoring type-parameters) as `oldbuffer`.
This method is already implemented for `oldbuffer::AbstractArray` and `oldbuffer::Tuple`,
and supports static arrays as well.
The deprecated version of this method which takes a `Dict` mapping symbols to values
instead of `idxs` and `vals` will dispatch to the new method. In addition if
no `remake_buffer` method exists with the new signature, it will call
`remake_buffer(sys, oldbuffer, Dict(idxs .=> vals))`.
Note that the new method signature allows `idxs` to be indexes, instead of requiring
that they be symbolic variables. Thus, any type which implements the new method must
also support indexes in `idxs`.
"""
function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
function remake_buffer(sys, oldbuffer::AbstractArray, idxs, vals)
# similar when used with an `MArray` and nonconcrete eltype returns a
# SizedArray. `similar_type` still returns an `MArray`
if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray)
elT = Union{}
for val in values(vals)
for val in vals
if val isa AbstractArray
valT = eltype(val)
else
Expand All @@ -29,7 +39,8 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)

newbuffer = similar(oldbuffer, elT)
copyto!(newbuffer, oldbuffer)
for (k, v) in vals
for (k, v) in zip(idxs, vals)
is_variable(sys, k) || is_parameter(sys, k) || continue
if v isa AbstractArray
v = elT.(v)
else
Expand All @@ -38,12 +49,16 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
setu(sys, k)(newbuffer, v)
end
else
mutbuffer = remake_buffer(sys, collect(oldbuffer), vals)
mutbuffer = remake_buffer(sys, collect(oldbuffer), idxs, vals)
newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer)
end
return newbuffer
end

function remake_buffer(sys, oldbuffer, idxs, vals)
remake_buffer(sys, oldbuffer, Dict(idxs .=> vals))
end

mutable struct TupleRemakeWrapper
t::Tuple
end
Expand All @@ -54,8 +69,19 @@ function set_parameter!(sys::TupleRemakeWrapper, val, idx)
sys.t = tp
end

function remake_buffer(sys, oldbuffer::Tuple, vals::Dict)
function set_state!(sys::TupleRemakeWrapper, val, idx)
tp = sys.t
@reset tp[idx] = val
sys.t = tp
end

function remake_buffer(sys, oldbuffer::Tuple, idxs, vals)
wrap = TupleRemakeWrapper(oldbuffer)
setu(sys, collect(keys(vals)))(wrap, values(vals))
setu(sys, idxs)(wrap, vals)
return wrap.t
end

@deprecate remake_buffer(sys, oldbuffer, vals::Dict) remake_buffer(
sys, oldbuffer, keys(vals), values(vals))
@deprecate remake_buffer(sys, oldbuffer::Tuple, vals::Dict) remake_buffer(
sys, oldbuffer, collect(keys(vals)), collect(values(vals)))
2 changes: 1 addition & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
SymbolicUtils = "<1.6"
SymbolicUtils = "3.2"
2 changes: 1 addition & 1 deletion test/downstream/remake_arrayvars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ using SymbolicIndexingInterface
sys = complete(sys)

u0 = [1.0, 2.0, 3.0]
newu0 = remake_buffer(sys, u0, Dict(x => [5.0, 6.0], y => 7.0))
newu0 = remake_buffer(sys, u0, [x, y], ([5.0, 6.0], 7.0))
@test newu0 == [5.0, 6.0, 7.0]
12 changes: 12 additions & 0 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ for sys in [
@test getter(fi) == []
getter = getp(sys, ())
@test getter(fi) == ()

for (sym, val) in [
(:a, 1.0f1),
(1, 1.0f1),
([:a, :b], [1.0f1, 2.0f1]),
((:b, :c), (2.0f1, 3.0f1))
]
setter = setp_oop(fi, sym)
newp = setter(fi, val)
getter = getp(sys, sym)
@test getter(newp) == val
end
end
end

Expand Down
47 changes: 30 additions & 17 deletions test/remake_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,60 @@ using StaticArrays

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)

for (buf, newbuf, newvals) in [
for (buf, newbuf, idxs, vals) in [
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [:x, :y, :z], [2.0, 3.0, 4.0]),
# buffer type "demotion"
([1.0, 2.0, 3.0], [2, 2, 3], Dict(:x => 2)),
([1.0, 2.0, 3.0], [2, 2, 3], [:x], [2]),
# buffer type promotion
([1, 2, 3], [2.0, 2.0, 3.0], Dict(:x => 2.0)),
([1, 2, 3], [2.0, 2.0, 3.0], [:x], [2.0]),
# value type promotion
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2, :y => 3.0, :z => 4.0)),
([1, 2, 3], [2.0, 3.0, 4.0], [:x, :y, :z], Real[2, 3.0, 4.0]),
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [:a, :b, :c], [2.0, 3.0, 4.0]),
# buffer type "demotion"
([1.0, 2.0, 3.0], [2, 2, 3], Dict(:a => 2)),
([1.0, 2.0, 3.0], [2, 2, 3], [:a], [2]),
# buffer type promotion
([1, 2, 3], [2.0, 2.0, 3.0], Dict(:a => 2.0)),
([1, 2, 3], [2.0, 2.0, 3.0], [:a], [2.0]),
# value type promotion
([1, 2, 3], [2, 3.0, 4.0], Dict(:a => 2, :b => 3.0, :c => 4.0))
([1, 2, 3], [2, 3.0, 4.0], [:a, :b, :c], Real[2, 3.0, 4.0]),
# skip non-parameters
([1, 2, 3], [2.0, 3.0, 3.0], [:a, :b, :(a + b)], [2.0, 3.0, 5.0])
]
for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}]
buf = arrType(buf)
newbuf = arrType(newbuf)
_newbuf = remake_buffer(sys, buf, newvals)
_newbuf = remake_buffer(sys, buf, idxs, vals)

@test _newbuf != buf # should not alias
@test newbuf == _newbuf # test values
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
@test_deprecated remake_buffer(sys, buf, Dict(idxs .=> vals))
end
end

# Tuples not allowed for state
for (buf, newbuf, newvals) in [
for (buf, newbuf, idxs, vals) in [
# standard operation
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), [:a, :b, :c], [2.0, 3.0, 4.0]),
# buffer type "demotion"
((1.0, 2.0, 3.0), (2, 3, 4), Dict(:a => 2, :b => 3, :c => 4)),
((1.0, 2.0, 3.0), (2, 3, 4), [:a, :b, :c], [2, 3, 4]),
# buffer type promotion
((1, 2, 3), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
((1, 2, 3), (2.0, 3.0, 4.0), [:a, :b, :c], [2.0, 3.0, 4.0]),
# value type promotion
((1, 2, 3), (2, 3.0, 4.0), Dict(:a => 2, :b => 3.0, :c => 4.0))
((1, 2, 3), (2, 3.0, 4.0), [:a, :b, :c], Real[2, 3.0, 4.0]),
# standard operation
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), [:x, :y, :z], [2.0, 3.0, 4.0]),
# buffer type "demotion"
((1.0, 2.0, 3.0), (2, 3, 4), [:x, :y, :z], [2, 3, 4]),
# buffer type promotion
((1, 2, 3), (2.0, 3.0, 4.0), [:x, :y, :z], [2.0, 3.0, 4.0]),
# value type promotion
((1, 2, 3), (2, 3.0, 4.0), [:x, :y, :z], Real[2, 3.0, 4.0]),
# skip non-variables
([1, 2, 3], [2.0, 3.0, 3.0], [:x, :y, :(x + y)], [2.0, 3.0, 5.0])
]
_newbuf = remake_buffer(sys, buf, newvals)
_newbuf = remake_buffer(sys, buf, idxs, vals)
@test newbuf == _newbuf # test values
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
@test_deprecated remake_buffer(sys, buf, Dict(idxs .=> vals))
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,7 @@ if GROUP == "All" || GROUP == "Downstream"
@safetestset "BatchedInterface with array symbolics test" begin
@time include("downstream/batchedinterface_arrayvars.jl")
end
@safetestset "remake_buffer with array symbolics test" begin
@time include("downstream/remake_arrayvars.jl")
end
end

0 comments on commit 8179344

Please sign in to comment.