Skip to content

Commit

Permalink
refactor: improve getu performance for vectors involving observed qua…
Browse files Browse the repository at this point in the history
…ntities
  • Loading branch information
AayushSabharwal committed Jan 30, 2024
1 parent dab25be commit ab29ac0
Showing 1 changed file with 56 additions and 18 deletions.
74 changes: 56 additions & 18 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,27 +189,65 @@ for (t1, t2) in [
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
]
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
getters = getu.((sys,), sym)
_call(getter, args...) = getter(args...)
return let getters = getters, _call = _call
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
function _getter(::Timeseries, prob)
broadcast(i -> map(g -> _call(g, prob, i), getters),
eachindex(state_values(prob)))
end
function _getter(::Timeseries, prob, i)
return map(g -> _call(g, prob, i), getters)
end
num_observed = count(x -> is_observed(sys, x), sym)
if num_observed <= 1
getters = getu.((sys,), sym)
_call(getter, args...) = getter(args...)
return let getters = getters, _call = _call
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
function _getter(::Timeseries, prob)
broadcast(i -> map(g -> _call(g, prob, i), getters),

Check warning on line 199 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L192-L199

Added lines #L192 - L199 were not covered by tests
eachindex(state_values(prob)))
end
function _getter(::Timeseries, prob, i)
return map(g -> _call(g, prob, i), getters)

Check warning on line 203 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L202-L203

Added lines #L202 - L203 were not covered by tests
end

# Need another scope for this to not box `_getter`
let _getter = _getter
function getter(prob)
return _getter(is_timeseries(prob), prob)
# Need another scope for this to not box `_getter`
let _getter = _getter
function getter(prob)
return _getter(is_timeseries(prob), prob)

Check warning on line 209 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L207-L209

Added lines #L207 - L209 were not covered by tests
end
function getter(prob, i)
return _getter(is_timeseries(prob), prob, i)

Check warning on line 212 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L211-L212

Added lines #L211 - L212 were not covered by tests
end
getter

Check warning on line 214 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L214

Added line #L214 was not covered by tests
end
end
else
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
return let obs = obs, is_tuple = sym isa Tuple
function _getter2(::NotTimeseries, prob)
obs(state_values(prob), parameter_values(prob), current_time(prob))

Check warning on line 221 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L218-L221

Added lines #L218 - L221 were not covered by tests
end
function _getter2(::Timeseries, prob)
obs.(state_values(prob), (parameter_values(prob),), current_time(prob))

Check warning on line 224 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L223-L224

Added lines #L223 - L224 were not covered by tests
end
function getter(prob, i)
return _getter(is_timeseries(prob), prob, i)
function _getter2(::Timeseries, prob, i)
obs(state_values(prob, i), parameter_values(prob), current_time(prob, i))

Check warning on line 227 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L226-L227

Added lines #L226 - L227 were not covered by tests
end

if is_tuple
let _getter2 = _getter2
function getter2(prob)
Tuple(_getter2(is_timeseries(prob), prob))

Check warning on line 233 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L230-L233

Added lines #L230 - L233 were not covered by tests
end
function getter2(prob, i)
Tuple(_getter2(is_timeseries(prob), prob, i))

Check warning on line 236 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L235-L236

Added lines #L235 - L236 were not covered by tests
end
getter2

Check warning on line 238 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L238

Added line #L238 was not covered by tests
end
else
let _getter2 = _getter2
function getter3(prob)
_getter2(is_timeseries(prob), prob)

Check warning on line 243 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L241-L243

Added lines #L241 - L243 were not covered by tests
end
function getter3(prob, i)
_getter2(is_timeseries(prob), prob, i)

Check warning on line 246 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L245-L246

Added lines #L245 - L246 were not covered by tests
end
getter3

Check warning on line 248 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L248

Added line #L248 was not covered by tests
end
end
getter
end
end
end
Expand Down

0 comments on commit ab29ac0

Please sign in to comment.