Skip to content

Commit

Permalink
Merge pull request #2796 from AayushSabharwal/as/common-obs
Browse files Browse the repository at this point in the history
refactor: use common implementation of `observedfun`
  • Loading branch information
ChrisRackauckas authored Jun 12, 2024
2 parents 8295ed1 + 4ece767 commit 9050e70
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 136 deletions.
24 changes: 24 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,30 @@ end
###
### System utils
###
struct ObservedFunctionCache{S}
sys::S
dict::Dict{Any, Any}
end

function ObservedFunctionCache(sys)
return ObservedFunctionCache(sys, Dict())
let sys = sys, dict = Dict()
function generated_observed(obsvar, args...)
end
end
end

function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(ofc.sys, obsvar)
end
if args === ()
return obs
else
return obs(args...)
end
end

function push_vars!(stmt, name, typ, vars)
isempty(vars) && return
vars_expr = Expr(:macrocall, typ, nothing)
Expand Down
94 changes: 10 additions & 84 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,82 +404,25 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,

obs = observed(sys)
observedfun = if steady_state
let sys = sys, dict = Dict(), ps = ps
let sys = sys, dict = Dict()
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar)
SymbolicIndexingInterface.observed(sys, obsvar)
end
if args === ()
let obs = obs, ps_T = typeof(ps)
(u, p, t = Inf) -> if p isa MTKParameters
obs(u, p..., t)
elseif ps_T <: Tuple
obs(u, p..., t)
else
obs(u, p, t)
end
return let obs = obs
fn1(u, p, t = Inf) = obs(u, p, t)
fn1
end
elseif length(args) == 2
return obs(args..., Inf)
else
if args[2] isa MTKParameters
if length(args) == 2
u, p = args
obs(u, p..., Inf)
else
u, p, t = args
obs(u, p..., t)
end
elseif ps isa Tuple
if length(args) == 2
u, p = args
obs(u, p..., Inf)
else
u, p, t = args
obs(u, p..., t)
end
else
if length(args) == 2
u, p = args
obs(u, p, Inf)
else
u, p, t = args
obs(u, p, t)
end
end
return obs(args...)
end
end
end
else
let sys = sys, dict = Dict(), ps = ps
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys,
obsvar;
checkbounds = checkbounds,
ps)
end
if args === ()
let obs = obs, ps_T = typeof(ps)
(u, p, t) -> if p isa MTKParameters
obs(u, p..., t)
elseif ps_T <: Tuple
obs(u, p..., t)
else
obs(u, p, t)
end
end
else
u, p, t = args
if p isa MTKParameters
u, p, t = args
obs(u, p..., t)
elseif ps isa Tuple # split parameters
obs(u, p..., t)
else
obs(args...)
end
end
end
end
ObservedFunctionCache(sys)
end

jac_prototype = if sparse
Expand Down Expand Up @@ -571,24 +514,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
_jac = nothing
end

obs = observed(sys)
observedfun = let sys = sys, dict = Dict()
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
end
if args === ()
let obs = obs
fun(u, p, t) = obs(u, p, t)
fun(u, p::MTKParameters, t) = obs(u, p..., t)
fun
end
else
u, p, t = args
p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t)
end
end
end
observedfun = ObservedFunctionCache(sys)

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down
14 changes: 1 addition & 13 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,19 +484,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
M = calculate_massmatrix(sys)
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

obs = observed(sys)
observedfun = let sys = sys, dict = Dict()
function generated_observed(obsvar, u, p, t)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
end
if p isa MTKParameters
obs(u, p..., t)
else
obs(u, p, t)
end
end
end
observedfun = ObservedFunctionCache(sys)

SDEFunction{iip}(f, g,
sys = sys,
Expand Down
9 changes: 1 addition & 8 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
end

observedfun = let sys = sys, dict = Dict()
function generate_observed(obsvar, u, p, t)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar)
end
p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t)
end
end
observedfun = ObservedFunctionCache(sys)

DiscreteFunction{iip, specialize}(f;
sys = sys,
Expand Down
11 changes: 1 addition & 10 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,16 +345,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,

f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

# just taken from abstractodesystem.jl for ODEFunction def
obs = observed(sys)
observedfun = let sys = sys, dict = Dict()
function generated_observed(obsvar, u, p, t)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
end
p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t)
end
end
observedfun = ObservedFunctionCache(sys)

df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
DiscreteProblem(df, u0, tspan, p; kwargs...)
Expand Down
22 changes: 1 addition & 21 deletions src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,27 +337,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
hess_prototype = nothing
end

observedfun = let sys = sys, dict = Dict()
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar)
end
if args === ()
let obs = obs
_obs(u, p) = obs(u, p)
_obs(u, p::MTKParameters) = obs(u, p...)
_obs
end
else
u, p = args
if p isa MTKParameters
obs(u, p...)
else
obs(u, p)
end
end
end
end
observedfun = ObservedFunctionCache(sys)

if length(cstr) > 0
@named cons_sys = ConstraintsSystem(cstr, dvs, ps)
Expand Down

0 comments on commit 9050e70

Please sign in to comment.