Skip to content

Commit

Permalink
Merge pull request #93 from SciML/as/param-obs-no-t
Browse files Browse the repository at this point in the history
feat: allow calling parameter observed functions with parameter object
  • Loading branch information
AayushSabharwal authored Aug 7, 2024
2 parents 58e0df3 + 680276e commit fc54236
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 4 deletions.
31 changes: 27 additions & 4 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,24 @@ for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any]
end

function (gpo::GetParameterObserved{<:Vector})(::NotTimeseries, prob)
gpo.obsfn(parameter_values(prob), current_time(prob))
# if the method doesn't exist or is an identity function, then `prob` itself
# is the parameter object, so use that and pass `nothing` for the time expecting
# it to not be used
if hasmethod(parameter_values, Tuple{typeof(prob)}) &&
(ps = parameter_values(prob)) != prob
gpo.obsfn(ps, current_time(prob))
else
gpo.obsfn(prob, nothing)
end
end
function (gpo::GetParameterObserved{<:Vector, true})(
buffer::AbstractArray, ::NotTimeseries, prob)
gpo.obsfn(buffer, parameter_values(prob), current_time(prob))
if hasmethod(parameter_values, Tuple{typeof(prob)}) &&
(ps = parameter_values(prob)) != prob
gpo.obsfn(buffer, ps, current_time(prob))
else
gpo.obsfn(buffer, prob, nothing)
end
end
function (gpo::GetParameterObserved{<:Vector})(::Timeseries, prob)
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
Expand All @@ -224,10 +237,20 @@ function (gpo::GetParameterObserved{<:Vector, false})(::AbstractArray, ::Timeser
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
end
function (gpo::GetParameterObserved)(::NotTimeseries, prob)
gpo.obsfn(parameter_values(prob), current_time(prob))
if hasmethod(parameter_values, Tuple{typeof(prob)}) &&
(ps = parameter_values(prob)) != prob
gpo.obsfn(ps, current_time(prob))
else
gpo.obsfn(prob, nothing)
end
end
function (gpo::GetParameterObserved)(buffer::AbstractArray, ::NotTimeseries, prob)
gpo.obsfn(buffer, parameter_values(prob), current_time(prob))
if hasmethod(parameter_values, Tuple{typeof(prob)}) &&
(ps = parameter_values(prob)) != prob
gpo.obsfn(buffer, ps, current_time(prob))
else
gpo.obsfn(buffer, prob, nothing)
end
return buffer
end
function (gpo::GetParameterObserved)(::Timeseries, prob)
Expand Down
34 changes: 34 additions & 0 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,40 @@ for sys in [
end
end

for (sym, val, check_inference) in [
(:(a + b), p[1] + p[2], true),
([:(a + b), :(a * b)], [p[1] + p[2], p[1] * p[2]], true),
((:(a + b), :(a * b)), (p[1] + p[2], p[1] * p[2]), true),
([:(a + c), :(a + b)], [p[1] + p[3], p[1] + p[2]], true)
]
get = getp(sys, sym)
if check_inference
@inferred get(parameter_values(fi))
end
@test get(parameter_values(fi)) == val
if sym isa Union{Array, Tuple}
buffer = zeros(length(sym))
if check_inference
@inferred get(buffer, parameter_values(fi))
else
get(buffer, parameter_values(fi))
end
@test buffer == collect(val)
end
end

for sym in [
:(a + t),
[:(a + t), :(a * b)],
(:(a + t), :(a * b))
]
get = getp(sys, sym)
@test_throws MethodError get(parameter_values(fi))
if sym isa Union{Array, Tuple}
@test_throws MethodError get(zeros(length(sym)), parameter_values(fi))
end
end

getter = getp(sys, [])
@test getter(fi) == []
getter = getp(sys, ())
Expand Down

0 comments on commit fc54236

Please sign in to comment.