diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 6cbab3f1..d7b73cf7 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -7,12 +7,15 @@ argument version of this function returns the parameter value at index `i`. The two-argument version of this function will default to returning `parameter_values(p)[i]`. -If this function is called with an `AbstractArray`, it will return the same array. +If this function is called with an `AbstractArray` or `Tuple`, it will return the same +array/tuple. """ function parameter_values end parameter_values(arr::AbstractArray) = arr +parameter_values(arr::Tuple) = arr parameter_values(arr::AbstractArray, i) = arr[i] +parameter_values(arr::Tuple, i) = arr[i] parameter_values(prob, i) = parameter_values(parameter_values(prob), i) """ @@ -77,7 +80,8 @@ See: [`parameter_values`](@ref) """ function set_parameter! end -function set_parameter!(sys::AbstractArray, val, idx) +# Tuple only included for the error message +function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx) sys[idx] = val end set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 76846598..42dd9500 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -17,78 +17,86 @@ function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, end sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -p = [1.0, 2.0, 3.0] -fi = FakeIntegrator(sys, copy(p), Ref(0)) -new_p = [4.0, 5.0, 6.0] -@test parameter_timeseries(fi) == [0] -for (sym, oldval, newval, check_inference) in [ - (:a, p[1], new_p[1], true), - (1, p[1], new_p[1], true), - ([:a, :b], p[1:2], new_p[1:2], true), - (1:2, p[1:2], new_p[1:2], true), - ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), - ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), - ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) -] - get = getp(sys, sym) - set! = setp(sys, sym) - if check_inference - @inferred get(fi) - end - @test get(fi) == fi.ps[sym] - @test get(fi) == oldval - @test fi.counter[] == 0 - if check_inference - @inferred set!(fi, newval) - else - set!(fi, newval) - end - @test fi.counter[] == 1 +for pType in [Vector, Tuple] + p = [1.0, 2.0, 3.0] + fi = FakeIntegrator(sys, pType(copy(p)), Ref(0)) + new_p = [4.0, 5.0, 6.0] + @test parameter_timeseries(fi) == [0] + for (sym, oldval, newval, check_inference) in [ + (:a, p[1], new_p[1], true), + (1, p[1], new_p[1], true), + ([:a, :b], p[1:2], new_p[1:2], true), + (1:2, p[1:2], new_p[1:2], true), + ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) + ] + get = getp(sys, sym) + set! = setp(sys, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == fi.ps[sym] + @test get(fi) == oldval - @test get(fi) == newval - set!(fi, oldval) - @test get(fi) == oldval - @test fi.counter[] == 2 + if pType === Tuple + @test_throws MethodError set!(fi, newval) + continue + end - fi.ps[sym] = newval - @test get(fi) == newval - @test fi.counter[] == 3 - fi.ps[sym] = oldval - @test get(fi) == oldval - @test fi.counter[] == 4 + @test fi.counter[] == 0 + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test fi.counter[] == 1 - if check_inference - @inferred get(p) - end - @test get(p) == oldval - if check_inference - @inferred set!(p, newval) - else - set!(p, newval) + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval + @test fi.counter[] == 2 + + fi.ps[sym] = newval + @test get(fi) == newval + @test fi.counter[] == 3 + fi.ps[sym] = oldval + @test get(fi) == oldval + @test fi.counter[] == 4 + + if check_inference + @inferred get(p) + end + @test get(p) == oldval + if check_inference + @inferred set!(p, newval) + else + set!(p, newval) + end + @test get(p) == newval + set!(p, oldval) + @test get(p) == oldval + @test fi.counter[] == 4 + fi.counter[] = 0 end - @test get(p) == newval - set!(p, oldval) - @test get(p) == oldval - @test fi.counter[] == 4 - fi.counter[] = 0 -end -for (sym, val) in [ - ([:a, :b, :c], p), - ([:c, :a], p[[3, 1]]), - ((:b, :a), p[[2, 1]]), - ((1, :c), p[[1, 3]]) -] - buffer = zeros(length(sym)) - get = getp(sys, sym) - @inferred get(buffer, fi) - @test buffer == val + for (sym, val) in [ + ([:a, :b, :c], p), + ([:c, :a], p[[3, 1]]), + ((:b, :a), p[[2, 1]]), + ((1, :c), p[[1, 3]]) + ] + buffer = zeros(length(sym)) + get = getp(sys, sym) + @inferred get(buffer, fi) + @test buffer == val + end end struct FakeSolution