diff --git a/Project.toml b/Project.toml index 9cfdd5b..d5b5b5f 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["Aayush Sabharwal and contributors"] version = "0.3.13" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [compat] +Accessors = "0.1.36" Aqua = "0.8" ArrayInterface = "7.9" MacroTools = "0.5.13" diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 4f7eb04..af90a25 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -4,6 +4,7 @@ import MacroTools using RuntimeGeneratedFunctions import StaticArraysCore: MArray, similar_type import ArrayInterface +using Accessors: @reset RuntimeGeneratedFunctions.init(@__MODULE__) @@ -22,8 +23,9 @@ include("interface.jl") export SymbolCache include("symbol_cache.jl") -export parameter_values, set_parameter!, parameter_values_at_time, - parameter_values_at_state_time, parameter_timeseries, getp, setp +export parameter_values, set_parameter!, finalize_parameters_hook!, + parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp, + setp include("parameter_indexing.jl") export state_values, set_state!, current_time, getu, setu diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 6cbab3f..d7b73cf 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -7,12 +7,15 @@ argument version of this function returns the parameter value at index `i`. The two-argument version of this function will default to returning `parameter_values(p)[i]`. -If this function is called with an `AbstractArray`, it will return the same array. +If this function is called with an `AbstractArray` or `Tuple`, it will return the same +array/tuple. """ function parameter_values end parameter_values(arr::AbstractArray) = arr +parameter_values(arr::Tuple) = arr parameter_values(arr::AbstractArray, i) = arr[i] +parameter_values(arr::Tuple, i) = arr[i] parameter_values(prob, i) = parameter_values(parameter_values(prob), i) """ @@ -77,7 +80,8 @@ See: [`parameter_values`](@ref) """ function set_parameter! end -function set_parameter!(sys::AbstractArray, val, idx) +# Tuple only included for the error message +function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx) sys[idx] = val end set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx) diff --git a/src/remake.jl b/src/remake.jl index f53b2b2..2dbfe08 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -6,8 +6,8 @@ are symbolic variables whose index in the buffer is determined using `sys`. The values in `vals` may not match the types of values stored at the corresponding indexes in the buffer, in which case the type of the buffer should be promoted accordingly. In general, this method should attempt to preserve the types of values stored in `vals` as -much as possible. The returned buffer should be of the same type (ignoring type-parameters) -as `oldbuffer`. +much as possible. Types can be promoted for type-stability, to maintain performance. The +returned buffer should be of the same type (ignoring type-parameters) as `oldbuffer`. This method is already implemented for `remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays @@ -19,14 +19,30 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict) if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray) elT = Union{} for val in values(vals) - elT = Union{elT, typeof(val)} + elT = promote_type(elT, typeof(val)) end newbuffer = similar(oldbuffer, elT) - setu(sys, keys(vals))(newbuffer, values(vals)) + setu(sys, collect(keys(vals)))(newbuffer, elT.(values(vals))) else mutbuffer = remake_buffer(sys, collect(oldbuffer), vals) newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer) end return newbuffer end + +mutable struct TupleRemakeWrapper + t::Tuple +end + +function set_parameter!(sys::TupleRemakeWrapper, val, idx) + tp = sys.t + @reset tp[idx] = val + sys.t = tp +end + +function remake_buffer(sys, oldbuffer::Tuple, vals::Dict) + wrap = TupleRemakeWrapper(oldbuffer) + setu(sys, collect(keys(vals)))(wrap, values(vals)) + return wrap.t +end diff --git a/src/trait.jl b/src/trait.jl index 6d742de..ea7964a 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -47,7 +47,7 @@ symbolic_type(::Type{Expr}) = ScalarSymbolic() """ hasname(x) -Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`. +Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`. Defaults to `true` for `x::Symbol`. """ function hasname end @@ -57,9 +57,11 @@ hasname(::Any) = false """ getname(x)::Symbol -Get the name of a symbolic variable as a `Symbol` +Get the name of a symbolic variable as a `Symbol`. Acts as the identity function for +`x::Symbol`. """ function getname end +getname(x::Symbol) = x """ symbolic_evaluate(expr, syms::Dict; kwargs...) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 7684659..42dd950 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -17,78 +17,86 @@ function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, end sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -p = [1.0, 2.0, 3.0] -fi = FakeIntegrator(sys, copy(p), Ref(0)) -new_p = [4.0, 5.0, 6.0] -@test parameter_timeseries(fi) == [0] -for (sym, oldval, newval, check_inference) in [ - (:a, p[1], new_p[1], true), - (1, p[1], new_p[1], true), - ([:a, :b], p[1:2], new_p[1:2], true), - (1:2, p[1:2], new_p[1:2], true), - ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), - ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), - ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) -] - get = getp(sys, sym) - set! = setp(sys, sym) - if check_inference - @inferred get(fi) - end - @test get(fi) == fi.ps[sym] - @test get(fi) == oldval - @test fi.counter[] == 0 - if check_inference - @inferred set!(fi, newval) - else - set!(fi, newval) - end - @test fi.counter[] == 1 +for pType in [Vector, Tuple] + p = [1.0, 2.0, 3.0] + fi = FakeIntegrator(sys, pType(copy(p)), Ref(0)) + new_p = [4.0, 5.0, 6.0] + @test parameter_timeseries(fi) == [0] + for (sym, oldval, newval, check_inference) in [ + (:a, p[1], new_p[1], true), + (1, p[1], new_p[1], true), + ([:a, :b], p[1:2], new_p[1:2], true), + (1:2, p[1:2], new_p[1:2], true), + ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) + ] + get = getp(sys, sym) + set! = setp(sys, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == fi.ps[sym] + @test get(fi) == oldval - @test get(fi) == newval - set!(fi, oldval) - @test get(fi) == oldval - @test fi.counter[] == 2 + if pType === Tuple + @test_throws MethodError set!(fi, newval) + continue + end - fi.ps[sym] = newval - @test get(fi) == newval - @test fi.counter[] == 3 - fi.ps[sym] = oldval - @test get(fi) == oldval - @test fi.counter[] == 4 + @test fi.counter[] == 0 + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test fi.counter[] == 1 - if check_inference - @inferred get(p) - end - @test get(p) == oldval - if check_inference - @inferred set!(p, newval) - else - set!(p, newval) + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval + @test fi.counter[] == 2 + + fi.ps[sym] = newval + @test get(fi) == newval + @test fi.counter[] == 3 + fi.ps[sym] = oldval + @test get(fi) == oldval + @test fi.counter[] == 4 + + if check_inference + @inferred get(p) + end + @test get(p) == oldval + if check_inference + @inferred set!(p, newval) + else + set!(p, newval) + end + @test get(p) == newval + set!(p, oldval) + @test get(p) == oldval + @test fi.counter[] == 4 + fi.counter[] = 0 end - @test get(p) == newval - set!(p, oldval) - @test get(p) == oldval - @test fi.counter[] == 4 - fi.counter[] = 0 -end -for (sym, val) in [ - ([:a, :b, :c], p), - ([:c, :a], p[[3, 1]]), - ((:b, :a), p[[2, 1]]), - ((1, :c), p[[1, 3]]) -] - buffer = zeros(length(sym)) - get = getp(sys, sym) - @inferred get(buffer, fi) - @test buffer == val + for (sym, val) in [ + ([:a, :b, :c], p), + ([:c, :a], p[[3, 1]]), + ((:b, :a), p[[2, 1]]), + ((1, :c), p[[1, 3]]) + ] + buffer = zeros(length(sym)) + get = getp(sys, sym) + @inferred get(buffer, fi) + @test buffer == val + end end struct FakeSolution diff --git a/test/remake_test.jl b/test/remake_test.jl index a37826e..9f4bad1 100644 --- a/test/remake_test.jl +++ b/test/remake_test.jl @@ -4,30 +4,23 @@ using StaticArrays sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) for (buf, newbuf, newvals) in [ - # standard operation - ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], - Dict(:x => 2.0, :y => 3.0, :z => 4.0)) - # type "demotion" - ([1.0, 2.0, 3.0], [2, 3, 4], - Dict(:x => 2, :y => 3, :z => 4)) - # type promotion - ([1, 2, 3], [2.0, 3.0, 4.0], - Dict(:x => 2.0, :y => 3.0, :z => 4.0)) - # union - ([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0], - Dict(:x => 2, :y => 3.0, :z => 4.0)) - # standard operation - ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], - Dict(:a => 2.0, :b => 3.0, :c => 4.0)) - # type "demotion" - ([1.0, 2.0, 3.0], [2, 3, 4], - Dict(:a => 2, :b => 3, :c => 4)) - # type promotion - ([1, 2, 3], [2.0, 3.0, 4.0], - Dict(:a => 2.0, :b => 3.0, :c => 4.0)) - # union - ([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0], - Dict(:a => 2, :b => 3.0, :c => 4.0))] + # standard operation + ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)), + # buffer type "demotion" + ([1.0, 2.0, 3.0], [2, 3, 4], Dict(:x => 2, :y => 3, :z => 4)), + # buffer type promotion + ([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)), + # value type promotion + ([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2, :y => 3.0, :z => 4.0)), + # standard operation + ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)), + # buffer type "demotion" + ([1.0, 2.0, 3.0], [2, 3, 4], Dict(:a => 2, :b => 3, :c => 4)), + # buffer type promotion + ([1, 2, 3], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)), + # value type promotion + ([1, 2, 3], [2, 3.0, 4.0], Dict(:a => 2, :b => 3.0, :c => 4.0)) +] for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}] buf = arrType(buf) newbuf = arrType(newbuf) @@ -38,3 +31,19 @@ for (buf, newbuf, newvals) in [ @test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type end end + +# Tuples not allowed for state +for (buf, newbuf, newvals) in [ + # standard operation + ((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)), + # buffer type "demotion" + ((1.0, 2.0, 3.0), (2, 3, 4), Dict(:a => 2, :b => 3, :c => 4)), + # buffer type promotion + ((1, 2, 3), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)), + # value type promotion + ((1, 2, 3), (2, 3.0, 4.0), Dict(:a => 2, :b => 3.0, :c => 4.0)) +] + _newbuf = remake_buffer(sys, buf, newvals) + @test newbuf == _newbuf # test values + @test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type +end diff --git a/test/trait_test.jl b/test/trait_test.jl index de7eb85..74de5d8 100644 --- a/test/trait_test.jl +++ b/test/trait_test.jl @@ -4,3 +4,8 @@ using Test @test all(symbolic_type.([Int, Float64, String, Bool, UInt, Complex{Float64}]) .== (NotSymbolic(),)) @test symbolic_type(Symbol) == ScalarSymbolic() +@test hasname(:x) +@test getname(:x) == :x +@test !hasname(1) +@test !hasname(1.0) +@test !hasname("x")