Skip to content

Commit

Permalink
Merge pull request #63 from SciML/as/tuple-parameter-values
Browse files Browse the repository at this point in the history
feat: support indexing Tuple parameters, add tests
  • Loading branch information
ChrisRackauckas authored Mar 28, 2024
2 parents ed4bce0 + 30b4759 commit 124e387
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 101 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ authors = ["Aayush Sabharwal <[email protected]> and contributors"]
version = "0.3.13"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[compat]
Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
MacroTools = "0.5.13"
Expand Down
6 changes: 4 additions & 2 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import MacroTools
using RuntimeGeneratedFunctions
import StaticArraysCore: MArray, similar_type
import ArrayInterface
using Accessors: @reset

RuntimeGeneratedFunctions.init(@__MODULE__)

Expand All @@ -22,8 +23,9 @@ include("interface.jl")
export SymbolCache
include("symbol_cache.jl")

export parameter_values, set_parameter!, parameter_values_at_time,
parameter_values_at_state_time, parameter_timeseries, getp, setp
export parameter_values, set_parameter!, finalize_parameters_hook!,
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp,
setp
include("parameter_indexing.jl")

export state_values, set_state!, current_time, getu, setu
Expand Down
8 changes: 6 additions & 2 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ argument version of this function returns the parameter value at index `i`. The
two-argument version of this function will default to returning
`parameter_values(p)[i]`.
If this function is called with an `AbstractArray`, it will return the same array.
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
array/tuple.
"""
function parameter_values end

parameter_values(arr::AbstractArray) = arr
parameter_values(arr::Tuple) = arr
parameter_values(arr::AbstractArray, i) = arr[i]
parameter_values(arr::Tuple, i) = arr[i]
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)

"""
Expand Down Expand Up @@ -77,7 +80,8 @@ See: [`parameter_values`](@ref)
"""
function set_parameter! end

function set_parameter!(sys::AbstractArray, val, idx)
# Tuple only included for the error message
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)
sys[idx] = val
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
Expand Down
24 changes: 20 additions & 4 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ are symbolic variables whose index in the buffer is determined using `sys`. The
values in `vals` may not match the types of values stored at the corresponding indexes in
the buffer, in which case the type of the buffer should be promoted accordingly. In
general, this method should attempt to preserve the types of values stored in `vals` as
much as possible. The returned buffer should be of the same type (ignoring type-parameters)
as `oldbuffer`.
much as possible. Types can be promoted for type-stability, to maintain performance. The
returned buffer should be of the same type (ignoring type-parameters) as `oldbuffer`.
This method is already implemented for
`remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays
Expand All @@ -19,14 +19,30 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray)
elT = Union{}
for val in values(vals)
elT = Union{elT, typeof(val)}
elT = promote_type(elT, typeof(val))
end

newbuffer = similar(oldbuffer, elT)
setu(sys, keys(vals))(newbuffer, values(vals))
setu(sys, collect(keys(vals)))(newbuffer, elT.(values(vals)))
else
mutbuffer = remake_buffer(sys, collect(oldbuffer), vals)
newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer)
end
return newbuffer
end

mutable struct TupleRemakeWrapper
t::Tuple
end

function set_parameter!(sys::TupleRemakeWrapper, val, idx)
tp = sys.t
@reset tp[idx] = val
sys.t = tp
end

function remake_buffer(sys, oldbuffer::Tuple, vals::Dict)
wrap = TupleRemakeWrapper(oldbuffer)
setu(sys, collect(keys(vals)))(wrap, values(vals))
return wrap.t
end
6 changes: 4 additions & 2 deletions src/trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ symbolic_type(::Type{Expr}) = ScalarSymbolic()
"""
hasname(x)
Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`.
Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`. Defaults to `true` for `x::Symbol`.
"""
function hasname end

Expand All @@ -57,9 +57,11 @@ hasname(::Any) = false
"""
getname(x)::Symbol
Get the name of a symbolic variable as a `Symbol`
Get the name of a symbolic variable as a `Symbol`. Acts as the identity function for
`x::Symbol`.
"""
function getname end
getname(x::Symbol) = x

"""
symbolic_evaluate(expr, syms::Dict; kwargs...)
Expand Down
142 changes: 75 additions & 67 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,78 +17,86 @@ function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator,
end

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(p), Ref(0))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
(:a, p[1], new_p[1], true),
(1, p[1], new_p[1], true),
([:a, :b], p[1:2], new_p[1:2], true),
(1:2, p[1:2], new_p[1:2], true),
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
]
get = getp(sys, sym)
set! = setp(sys, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval
@test fi.counter[] == 0
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test fi.counter[] == 1
for pType in [Vector, Tuple]
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, pType(copy(p)), Ref(0))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
(:a, p[1], new_p[1], true),
(1, p[1], new_p[1], true),
([:a, :b], p[1:2], new_p[1:2], true),
(1:2, p[1:2], new_p[1:2], true),
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
]
get = getp(sys, sym)
set! = setp(sys, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval

@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2
if pType === Tuple
@test_throws MethodError set!(fi, newval)
continue
end

fi.ps[sym] = newval
@test get(fi) == newval
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
@test fi.counter[] == 4
@test fi.counter[] == 0
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test fi.counter[] == 1

if check_inference
@inferred get(p)
end
@test get(p) == oldval
if check_inference
@inferred set!(p, newval)
else
set!(p, newval)
@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2

fi.ps[sym] = newval
@test get(fi) == newval
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
@test fi.counter[] == 4

if check_inference
@inferred get(p)
end
@test get(p) == oldval
if check_inference
@inferred set!(p, newval)
else
set!(p, newval)
end
@test get(p) == newval
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
fi.counter[] = 0
end
@test get(p) == newval
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
fi.counter[] = 0
end

for (sym, val) in [
([:a, :b, :c], p),
([:c, :a], p[[3, 1]]),
((:b, :a), p[[2, 1]]),
((1, :c), p[[1, 3]])
]
buffer = zeros(length(sym))
get = getp(sys, sym)
@inferred get(buffer, fi)
@test buffer == val
for (sym, val) in [
([:a, :b, :c], p),
([:c, :a], p[[3, 1]]),
((:b, :a), p[[2, 1]]),
((1, :c), p[[1, 3]])
]
buffer = zeros(length(sym))
get = getp(sys, sym)
@inferred get(buffer, fi)
@test buffer == val
end
end

struct FakeSolution
Expand Down
57 changes: 33 additions & 24 deletions test/remake_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,23 @@ using StaticArrays
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)

for (buf, newbuf, newvals) in [
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
# type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4],
Dict(:x => 2, :y => 3, :z => 4))
# type promotion
([1, 2, 3], [2.0, 3.0, 4.0],
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
# union
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
Dict(:x => 2, :y => 3.0, :z => 4.0))
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
# type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4],
Dict(:a => 2, :b => 3, :c => 4))
# type promotion
([1, 2, 3], [2.0, 3.0, 4.0],
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
# union
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
Dict(:a => 2, :b => 3.0, :c => 4.0))]
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
# buffer type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4], Dict(:x => 2, :y => 3, :z => 4)),
# buffer type promotion
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
# value type promotion
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2, :y => 3.0, :z => 4.0)),
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
# buffer type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4], Dict(:a => 2, :b => 3, :c => 4)),
# buffer type promotion
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
# value type promotion
([1, 2, 3], [2, 3.0, 4.0], Dict(:a => 2, :b => 3.0, :c => 4.0))
]
for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}]
buf = arrType(buf)
newbuf = arrType(newbuf)
Expand All @@ -38,3 +31,19 @@ for (buf, newbuf, newvals) in [
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
end
end

# Tuples not allowed for state
for (buf, newbuf, newvals) in [
# standard operation
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
# buffer type "demotion"
((1.0, 2.0, 3.0), (2, 3, 4), Dict(:a => 2, :b => 3, :c => 4)),
# buffer type promotion
((1, 2, 3), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
# value type promotion
((1, 2, 3), (2, 3.0, 4.0), Dict(:a => 2, :b => 3.0, :c => 4.0))
]
_newbuf = remake_buffer(sys, buf, newvals)
@test newbuf == _newbuf # test values
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
end
5 changes: 5 additions & 0 deletions test/trait_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ using Test
@test all(symbolic_type.([Int, Float64, String, Bool, UInt, Complex{Float64}]) .==
(NotSymbolic(),))
@test symbolic_type(Symbol) == ScalarSymbolic()
@test hasname(:x)
@test getname(:x) == :x
@test !hasname(1)
@test !hasname(1.0)
@test !hasname("x")

0 comments on commit 124e387

Please sign in to comment.