Skip to content

Commit

Permalink
Merge pull request #55 from SciML/as/getu-param
Browse files Browse the repository at this point in the history
fix: fix `getu` with parameter symbols
  • Loading branch information
ChrisRackauckas authored Mar 9, 2024
2 parents a7c70c8 + ad18dbf commit 27df802
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,16 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
return getu(sys, idx)
elseif is_parameter(sys, sym)
return let fn = getp(sys, sym)
getter(prob, args...) = fn(prob)
getter
_getter_p(::NotTimeseries, prob) = fn(prob)
function _getter_p(::Timeseries, prob)
[fn(parameter_values_at_state_time(prob, i))
for i in eachindex(current_time(prob))]
end
_getter_p(::Timeseries, prob, i) = fn(parameter_values_at_state_time(prob, i))
let _getter = _getter_p
getter(prob, args...) = _getter(is_timeseries(prob), prob, args...)
getter
end
end
elseif is_independent_variable(sys, sym)
_getter(::IsTimeseriesTrait, prob) = current_time(prob)
Expand Down
19 changes: 19 additions & 0 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ end
function Base.getproperty(fs::FakeSolution, s::Symbol)
s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s)
end
SymbolicIndexingInterface.state_values(fs::FakeSolution) = fs.u
SymbolicIndexingInterface.current_time(fs::FakeSolution) = fs.t
SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys
SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p[end]
SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[end][i]
Expand Down Expand Up @@ -149,3 +151,20 @@ for (sym, val, arrval, check_inference) in [
@test get(fs, sub_inds) == arrval[sub_inds]
end
end

ps = fs.p[2:2:end]
avals = getindex.(ps, 1)
bvals = getindex.(ps, 2)
cvals = getindex.(ps, 3)
for (sym, val, arrval) in [
(:a, p[1], avals),
((:b, :c), p[2:3], tuple.(bvals, cvals)),
([:c, :a], p[[3, 1]], vcat.(cvals, avals))
]
get = getu(sys, sym)
@inferred get(fs)
@test get(fs) == arrval
for i in eachindex(ps)
@test get(fs, i) == arrval[i]
end
end

0 comments on commit 27df802

Please sign in to comment.