Skip to content

Commit

Permalink
fix: add dedicated code paths for non-scalarized array symbolics in g…
Browse files Browse the repository at this point in the history
…etu/getp
  • Loading branch information
AayushSabharwal committed Apr 16, 2024
1 parent 90c9138 commit cca9f26
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any},
sys, NotSymbolic(), NotSymbolic(), idx)
return _getp(sys, NotSymbolic(), NotSymbolic(), idx)
end

for (t1, t2) in [
Expand Down Expand Up @@ -229,6 +228,11 @@ for (t1, t2) in [
end

function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p)
if is_parameter(sys, p)
idx = parameter_index(sys, p)
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any},
sys, NotSymbolic(), NotSymbolic(), idx)
end
return getp(sys, collect(p))
end

Expand Down Expand Up @@ -294,5 +298,9 @@ for (t1, t2) in [
end

function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p)
if is_parameter(sys, p)
idx = parameter_index(sys, p)
return setp(sys, idx; run_hook = false)
end
return setp(sys, collect(p); run_hook = false)
end
12 changes: 12 additions & 0 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ for (t1, t2) in [
end

function _getu(sys, ::ArraySymbolic, ::NotSymbolic, sym)
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return getu(sys, idx)
elseif is_parameter(sys, sym)
return getp(sys, sym)
end
return getu(sys, collect(sym))
end

Expand Down Expand Up @@ -295,5 +301,11 @@ for (t1, t2) in [
end

function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym)
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return setu(sys, idx)
elseif is_parameter(sys, sym)
return setp(sys, sym)
end
return setu(sys, collect(sym))
end

0 comments on commit cca9f26

Please sign in to comment.