Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support indexing Tuple parameters, add tests #63

Merged
merged 6 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ authors = ["Aayush Sabharwal <[email protected]> and contributors"]
version = "0.3.13"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[compat]
Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
MacroTools = "0.5.13"
Expand Down
6 changes: 4 additions & 2 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import MacroTools
using RuntimeGeneratedFunctions
import StaticArraysCore: MArray, similar_type
import ArrayInterface
using Accessors: @reset

RuntimeGeneratedFunctions.init(@__MODULE__)

Expand All @@ -22,8 +23,9 @@ include("interface.jl")
export SymbolCache
include("symbol_cache.jl")

export parameter_values, set_parameter!, parameter_values_at_time,
parameter_values_at_state_time, parameter_timeseries, getp, setp
export parameter_values, set_parameter!, finalize_parameters_hook!,
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp,
setp
include("parameter_indexing.jl")

export state_values, set_state!, current_time, getu, setu
Expand Down
8 changes: 6 additions & 2 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
two-argument version of this function will default to returning
`parameter_values(p)[i]`.

If this function is called with an `AbstractArray`, it will return the same array.
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
array/tuple.
"""
function parameter_values end

parameter_values(arr::AbstractArray) = arr
parameter_values(arr::Tuple) = arr

Check warning on line 16 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L16

Added line #L16 was not covered by tests
parameter_values(arr::AbstractArray, i) = arr[i]
parameter_values(arr::Tuple, i) = arr[i]

Check warning on line 18 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L18

Added line #L18 was not covered by tests
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)

"""
Expand Down Expand Up @@ -77,7 +80,8 @@
"""
function set_parameter! end

function set_parameter!(sys::AbstractArray, val, idx)
# Tuple only included for the error message
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)

Check warning on line 84 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L84

Added line #L84 was not covered by tests
sys[idx] = val
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
Expand Down
24 changes: 20 additions & 4 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
values in `vals` may not match the types of values stored at the corresponding indexes in
the buffer, in which case the type of the buffer should be promoted accordingly. In
general, this method should attempt to preserve the types of values stored in `vals` as
much as possible. The returned buffer should be of the same type (ignoring type-parameters)
as `oldbuffer`.
much as possible. Types can be promoted for type-stability, to maintain performance. The
returned buffer should be of the same type (ignoring type-parameters) as `oldbuffer`.

This method is already implemented for
`remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays
Expand All @@ -19,14 +19,30 @@
if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray)
elT = Union{}
for val in values(vals)
elT = Union{elT, typeof(val)}
elT = promote_type(elT, typeof(val))

Check warning on line 22 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L22

Added line #L22 was not covered by tests
end

newbuffer = similar(oldbuffer, elT)
setu(sys, keys(vals))(newbuffer, values(vals))
setu(sys, collect(keys(vals)))(newbuffer, elT.(values(vals)))

Check warning on line 26 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L26

Added line #L26 was not covered by tests
else
mutbuffer = remake_buffer(sys, collect(oldbuffer), vals)
newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer)
end
return newbuffer
end

mutable struct TupleRemakeWrapper
t::Tuple
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't infer, Tuple isn't concrete

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type of the tuple can change, which is why it's not fully typed

Comment on lines +34 to +35
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mutable struct TupleRemakeWrapper
t::Tuple
mutable struct TupleRemakeWrapper{T <: Tuple}
t::T

?

end

function set_parameter!(sys::TupleRemakeWrapper, val, idx)
tp = sys.t
@reset tp[idx] = val
sys.t = tp

Check warning on line 41 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L38-L41

Added lines #L38 - L41 were not covered by tests
end

function remake_buffer(sys, oldbuffer::Tuple, vals::Dict)
wrap = TupleRemakeWrapper(oldbuffer)
setu(sys, collect(keys(vals)))(wrap, values(vals))
return wrap.t

Check warning on line 47 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L44-L47

Added lines #L44 - L47 were not covered by tests
end
6 changes: 4 additions & 2 deletions src/trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"""
hasname(x)

Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`.
Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`. Defaults to `true` for `x::Symbol`.
"""
function hasname end

Expand All @@ -57,9 +57,11 @@
"""
getname(x)::Symbol

Get the name of a symbolic variable as a `Symbol`
Get the name of a symbolic variable as a `Symbol`. Acts as the identity function for
`x::Symbol`.
"""
function getname end
getname(x::Symbol) = x

Check warning on line 64 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L64

Added line #L64 was not covered by tests

"""
symbolic_evaluate(expr, syms::Dict; kwargs...)
Expand Down
142 changes: 75 additions & 67 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,78 +17,86 @@ function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator,
end

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(p), Ref(0))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
(:a, p[1], new_p[1], true),
(1, p[1], new_p[1], true),
([:a, :b], p[1:2], new_p[1:2], true),
(1:2, p[1:2], new_p[1:2], true),
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
]
get = getp(sys, sym)
set! = setp(sys, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval
@test fi.counter[] == 0
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test fi.counter[] == 1
for pType in [Vector, Tuple]
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, pType(copy(p)), Ref(0))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
(:a, p[1], new_p[1], true),
(1, p[1], new_p[1], true),
([:a, :b], p[1:2], new_p[1:2], true),
(1:2, p[1:2], new_p[1:2], true),
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
]
get = getp(sys, sym)
set! = setp(sys, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval

@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2
if pType === Tuple
@test_throws MethodError set!(fi, newval)
continue
end

fi.ps[sym] = newval
@test get(fi) == newval
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
@test fi.counter[] == 4
@test fi.counter[] == 0
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test fi.counter[] == 1

if check_inference
@inferred get(p)
end
@test get(p) == oldval
if check_inference
@inferred set!(p, newval)
else
set!(p, newval)
@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2

fi.ps[sym] = newval
@test get(fi) == newval
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
@test fi.counter[] == 4

if check_inference
@inferred get(p)
end
@test get(p) == oldval
if check_inference
@inferred set!(p, newval)
else
set!(p, newval)
end
@test get(p) == newval
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
fi.counter[] = 0
end
@test get(p) == newval
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
fi.counter[] = 0
end

for (sym, val) in [
([:a, :b, :c], p),
([:c, :a], p[[3, 1]]),
((:b, :a), p[[2, 1]]),
((1, :c), p[[1, 3]])
]
buffer = zeros(length(sym))
get = getp(sys, sym)
@inferred get(buffer, fi)
@test buffer == val
for (sym, val) in [
([:a, :b, :c], p),
([:c, :a], p[[3, 1]]),
((:b, :a), p[[2, 1]]),
((1, :c), p[[1, 3]])
]
buffer = zeros(length(sym))
get = getp(sys, sym)
@inferred get(buffer, fi)
@test buffer == val
end
end

struct FakeSolution
Expand Down
57 changes: 33 additions & 24 deletions test/remake_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,23 @@ using StaticArrays
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)

for (buf, newbuf, newvals) in [
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
# type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4],
Dict(:x => 2, :y => 3, :z => 4))
# type promotion
([1, 2, 3], [2.0, 3.0, 4.0],
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
# union
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
Dict(:x => 2, :y => 3.0, :z => 4.0))
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
# type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4],
Dict(:a => 2, :b => 3, :c => 4))
# type promotion
([1, 2, 3], [2.0, 3.0, 4.0],
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
# union
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
Dict(:a => 2, :b => 3.0, :c => 4.0))]
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
# buffer type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4], Dict(:x => 2, :y => 3, :z => 4)),
# buffer type promotion
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
# value type promotion
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2, :y => 3.0, :z => 4.0)),
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
# buffer type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4], Dict(:a => 2, :b => 3, :c => 4)),
# buffer type promotion
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
# value type promotion
([1, 2, 3], [2, 3.0, 4.0], Dict(:a => 2, :b => 3.0, :c => 4.0))
]
for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}]
buf = arrType(buf)
newbuf = arrType(newbuf)
Expand All @@ -38,3 +31,19 @@ for (buf, newbuf, newvals) in [
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
end
end

# Tuples not allowed for state
for (buf, newbuf, newvals) in [
# standard operation
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
# buffer type "demotion"
((1.0, 2.0, 3.0), (2, 3, 4), Dict(:a => 2, :b => 3, :c => 4)),
# buffer type promotion
((1, 2, 3), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
# value type promotion
((1, 2, 3), (2, 3.0, 4.0), Dict(:a => 2, :b => 3.0, :c => 4.0))
]
_newbuf = remake_buffer(sys, buf, newvals)
@test newbuf == _newbuf # test values
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
end
5 changes: 5 additions & 0 deletions test/trait_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ using Test
@test all(symbolic_type.([Int, Float64, String, Bool, UInt, Complex{Float64}]) .==
(NotSymbolic(),))
@test symbolic_type(Symbol) == ScalarSymbolic()
@test hasname(:x)
@test getname(:x) == :x
@test !hasname(1)
@test !hasname(1.0)
@test !hasname("x")
Loading