From 680276e720259bebd9bbc57b1cae34ea0abcc749 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 7 Aug 2024 15:55:39 +0530 Subject: [PATCH] feat: allow calling parameter observed functions with parameter object --- src/parameter_indexing.jl | 31 ++++++++++++++++++++++++++---- test/parameter_indexing_test.jl | 34 +++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index f64e7d1a..8f4d5ea5 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -208,11 +208,24 @@ for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any] end function (gpo::GetParameterObserved{<:Vector})(::NotTimeseries, prob) - gpo.obsfn(parameter_values(prob), current_time(prob)) + # if the method doesn't exist or is an identity function, then `prob` itself + # is the parameter object, so use that and pass `nothing` for the time expecting + # it to not be used + if hasmethod(parameter_values, Tuple{typeof(prob)}) && + (ps = parameter_values(prob)) != prob + gpo.obsfn(ps, current_time(prob)) + else + gpo.obsfn(prob, nothing) + end end function (gpo::GetParameterObserved{<:Vector, true})( buffer::AbstractArray, ::NotTimeseries, prob) - gpo.obsfn(buffer, parameter_values(prob), current_time(prob)) + if hasmethod(parameter_values, Tuple{typeof(prob)}) && + (ps = parameter_values(prob)) != prob + gpo.obsfn(buffer, ps, current_time(prob)) + else + gpo.obsfn(buffer, prob, nothing) + end end function (gpo::GetParameterObserved{<:Vector})(::Timeseries, prob) throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo))) @@ -224,10 +237,20 @@ function (gpo::GetParameterObserved{<:Vector, false})(::AbstractArray, ::Timeser throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo))) end function (gpo::GetParameterObserved)(::NotTimeseries, prob) - gpo.obsfn(parameter_values(prob), current_time(prob)) + if hasmethod(parameter_values, Tuple{typeof(prob)}) && + (ps = parameter_values(prob)) != prob + gpo.obsfn(ps, current_time(prob)) + else + gpo.obsfn(prob, nothing) + end end function (gpo::GetParameterObserved)(buffer::AbstractArray, ::NotTimeseries, prob) - gpo.obsfn(buffer, parameter_values(prob), current_time(prob)) + if hasmethod(parameter_values, Tuple{typeof(prob)}) && + (ps = parameter_values(prob)) != prob + gpo.obsfn(buffer, ps, current_time(prob)) + else + gpo.obsfn(buffer, prob, nothing) + end return buffer end function (gpo::GetParameterObserved)(::Timeseries, prob) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index b2347b29..43d6f5dc 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -130,6 +130,40 @@ for sys in [ end end + for (sym, val, check_inference) in [ + (:(a + b), p[1] + p[2], true), + ([:(a + b), :(a * b)], [p[1] + p[2], p[1] * p[2]], true), + ((:(a + b), :(a * b)), (p[1] + p[2], p[1] * p[2]), true), + ([:(a + c), :(a + b)], [p[1] + p[3], p[1] + p[2]], true) + ] + get = getp(sys, sym) + if check_inference + @inferred get(parameter_values(fi)) + end + @test get(parameter_values(fi)) == val + if sym isa Union{Array, Tuple} + buffer = zeros(length(sym)) + if check_inference + @inferred get(buffer, parameter_values(fi)) + else + get(buffer, parameter_values(fi)) + end + @test buffer == collect(val) + end + end + + for sym in [ + :(a + t), + [:(a + t), :(a * b)], + (:(a + t), :(a * b)) + ] + get = getp(sys, sym) + @test_throws MethodError get(parameter_values(fi)) + if sym isa Union{Array, Tuple} + @test_throws MethodError get(zeros(length(sym)), parameter_values(fi)) + end + end + getter = getp(sys, []) @test getter(fi) == [] getter = getp(sys, ())