Skip to content

Commit

Permalink
Merge pull request #41 from SciML/as/getu-perf
Browse files Browse the repository at this point in the history
refactor: improve getu performance for vectors involving observed quantities
  • Loading branch information
ChrisRackauckas authored Jan 30, 2024
2 parents dab25be + b8c9101 commit 7fdf807
Showing 1 changed file with 58 additions and 18 deletions.
76 changes: 58 additions & 18 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,27 +189,67 @@ for (t1, t2) in [
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
]
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
getters = getu.((sys,), sym)
_call(getter, args...) = getter(args...)
return let getters = getters, _call = _call
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
function _getter(::Timeseries, prob)
broadcast(i -> map(g -> _call(g, prob, i), getters),
eachindex(state_values(prob)))
end
function _getter(::Timeseries, prob, i)
return map(g -> _call(g, prob, i), getters)
end
num_observed = count(x -> is_observed(sys, x), sym)
if num_observed <= 1
getters = getu.((sys,), sym)
_call(getter, args...) = getter(args...)
return let getters = getters, _call = _call
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
function _getter(::Timeseries, prob)
broadcast(i -> map(g -> _call(g, prob, i), getters),
eachindex(state_values(prob)))
end
function _getter(::Timeseries, prob, i)
return map(g -> _call(g, prob, i), getters)
end

# Need another scope for this to not box `_getter`
let _getter = _getter
function getter(prob)
return _getter(is_timeseries(prob), prob)
# Need another scope for this to not box `_getter`
let _getter = _getter
function getter(prob)
return _getter(is_timeseries(prob), prob)
end
function getter(prob, i)
return _getter(is_timeseries(prob), prob, i)
end
getter
end
end
else
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
return let obs = obs, is_tuple = sym isa Tuple
function _getter2(::NotTimeseries, prob)
obs(state_values(prob), parameter_values(prob), current_time(prob))
end
function _getter2(::Timeseries, prob)
obs.(state_values(prob), (parameter_values(prob),), current_time(prob))
end
function getter(prob, i)
return _getter(is_timeseries(prob), prob, i)
function _getter2(::Timeseries, prob, i)
obs(state_values(prob, i),
parameter_values(prob),
current_time(prob, i))
end

if is_tuple
let _getter2 = _getter2
function getter2(prob)
Tuple(_getter2(is_timeseries(prob), prob))
end
function getter2(prob, i)
Tuple(_getter2(is_timeseries(prob), prob, i))
end
getter2
end
else
let _getter2 = _getter2
function getter3(prob)
_getter2(is_timeseries(prob), prob)
end
function getter3(prob, i)
_getter2(is_timeseries(prob), prob, i)
end
getter3
end
end
getter
end
end
end
Expand Down

0 comments on commit 7fdf807

Please sign in to comment.