From d2b322e6fd1d6929d91e9675007ee9bb290dc459 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Apr 2024 14:28:19 +0530 Subject: [PATCH] fix: add dedicated code paths for non-scalarized array symbolics in getu/getp --- src/parameter_indexing.jl | 9 +++++++++ src/state_indexing.jl | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index b52686d4..faa4551d 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -140,6 +140,11 @@ for (t1, t2) in [ end function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p) + if is_parameter(sys, p) + idx = parameter_index(sys, p) + return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, + sys, NotSymbolic(), NotSymbolic(), idx) + end return getp(sys, collect(p)) end @@ -205,5 +210,9 @@ for (t1, t2) in [ end function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p) + if is_parameter(sys, p) + idx = parameter_index(sys, p) + return setp(sys, idx; run_hook = false) + end return setp(sys, collect(p); run_hook = false) end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 8d51fd50..e854b29e 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -165,6 +165,12 @@ for (t1, t2) in [ end function _getu(sys, ::ArraySymbolic, ::NotSymbolic, sym) + if is_variable(sys, sym) + idx = variable_index(sys, sym) + return getu(sys, idx) + elseif is_parameter(sys, sym) + return getp(sys, sym) + end return getu(sys, collect(sym)) end @@ -221,5 +227,11 @@ for (t1, t2) in [ end function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym) + if is_variable(sys, sym) + idx = variable_index(sys, sym) + return setu(sys, idx) + elseif is_parameter(sys, sym) + return setp(sys, sym) + end return setu(sys, collect(sym)) end