From 4f5758c9836d268c5adbb8474cea207582b558cb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 21 Feb 2024 17:48:47 +0530 Subject: [PATCH] feat: add support for inplace getp --- src/parameter_indexing.jl | 18 ++++++++++++++---- test/parameter_indexing_test.jl | 12 ++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index eaa5f2c..b58fb79 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -57,8 +57,10 @@ end function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) - return function getter(sol) - return parameter_values(sol, idx) + return let idx = idx + function getter(sol) + return parameter_values(sol, idx) + end end end @@ -70,8 +72,16 @@ for (t1, t2) in [ @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) getters = getp.((sys,), p) - return function getter(sol) - map(g -> g(sol), getters) + return let getters = getters + function getter(sol) + map(g -> g(sol), getters) + end + function getter(buffer, sol) + for (i, g) in zip(eachindex(buffer), getters) + buffer[i] = g(sol) + end + buffer + end end end end diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index a2646f8..adac52f 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -56,3 +56,15 @@ for (sym, oldval, newval, check_inference) in [ set!(p, oldval) @test get(p) == oldval end + +for (sym, val) in [ + ([:a, :b, :c], p), + ([:c, :a], p[[3, 1]]), + ((:b, :a), p[[2, 1]]), + ((1, :c), p[[1, 3]]) +] + buffer = zeros(length(sym)) + get = getp(sys, sym) + @inferred get(buffer, fi) + @test buffer == val +end