From 4ece7677ce129fd561421afd359231a625c28c5d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 12 Jun 2024 15:34:24 +0530 Subject: [PATCH] refactor: use common implementation of `observedfun` --- src/systems/abstractsystem.jl | 24 +++++ src/systems/diffeqs/abstractodesystem.jl | 94 ++----------------- src/systems/diffeqs/sdesystem.jl | 14 +-- .../discrete_system/discrete_system.jl | 9 +- src/systems/jumps/jumpsystem.jl | 11 +-- .../optimization/optimizationsystem.jl | 22 +---- 6 files changed, 38 insertions(+), 136 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 66d4f6801c..7c842be5a5 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1192,6 +1192,30 @@ end ### ### System utils ### +struct ObservedFunctionCache{S} + sys::S + dict::Dict{Any, Any} +end + +function ObservedFunctionCache(sys) + return ObservedFunctionCache(sys, Dict()) + let sys = sys, dict = Dict() + function generated_observed(obsvar, args...) + end + end +end + +function (ofc::ObservedFunctionCache)(obsvar, args...) + obs = get!(ofc.dict, value(obsvar)) do + SymbolicIndexingInterface.observed(ofc.sys, obsvar) + end + if args === () + return obs + else + return obs(args...) + end +end + function push_vars!(stmt, name, typ, vars) isempty(vars) && return vars_expr = Expr(:macrocall, typ, nothing) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index da701560b9..292052ec3e 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -404,82 +404,25 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, obs = observed(sys) observedfun = if steady_state - let sys = sys, dict = Dict(), ps = ps + let sys = sys, dict = Dict() function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar) + SymbolicIndexingInterface.observed(sys, obsvar) end if args === () - let obs = obs, ps_T = typeof(ps) - (u, p, t = Inf) -> if p isa MTKParameters - obs(u, p..., t) - elseif ps_T <: Tuple - obs(u, p..., t) - else - obs(u, p, t) - end + return let obs = obs + fn1(u, p, t = Inf) = obs(u, p, t) + fn1 end + elseif length(args) == 2 + return obs(args..., Inf) else - if args[2] isa MTKParameters - if length(args) == 2 - u, p = args - obs(u, p..., Inf) - else - u, p, t = args - obs(u, p..., t) - end - elseif ps isa Tuple - if length(args) == 2 - u, p = args - obs(u, p..., Inf) - else - u, p, t = args - obs(u, p..., t) - end - else - if length(args) == 2 - u, p = args - obs(u, p, Inf) - else - u, p, t = args - obs(u, p, t) - end - end + return obs(args...) end end end else - let sys = sys, dict = Dict(), ps = ps - function generated_observed(obsvar, args...) - obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, - obsvar; - checkbounds = checkbounds, - ps) - end - if args === () - let obs = obs, ps_T = typeof(ps) - (u, p, t) -> if p isa MTKParameters - obs(u, p..., t) - elseif ps_T <: Tuple - obs(u, p..., t) - else - obs(u, p, t) - end - end - else - u, p, t = args - if p isa MTKParameters - u, p, t = args - obs(u, p..., t) - elseif ps isa Tuple # split parameters - obs(u, p..., t) - else - obs(args...) - end - end - end - end + ObservedFunctionCache(sys) end jac_prototype = if sparse @@ -571,24 +514,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) _jac = nothing end - obs = observed(sys) - observedfun = let sys = sys, dict = Dict() - function generated_observed(obsvar, args...) - obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) - end - if args === () - let obs = obs - fun(u, p, t) = obs(u, p, t) - fun(u, p::MTKParameters, t) = obs(u, p..., t) - fun - end - else - u, p, t = args - p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t) - end - end - end + observedfun = ObservedFunctionCache(sys) jac_prototype = if sparse uElType = u0 === nothing ? Float64 : eltype(u0) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 05d85ecd6b..223f1b3c0a 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -484,19 +484,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys), M = calculate_massmatrix(sys) _M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M) - obs = observed(sys) - observedfun = let sys = sys, dict = Dict() - function generated_observed(obsvar, u, p, t) - obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) - end - if p isa MTKParameters - obs(u, p..., t) - else - obs(u, p, t) - end - end - end + observedfun = ObservedFunctionCache(sys) SDEFunction{iip}(f, g, sys = sys, diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 984429a504..18755ebafb 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -330,14 +330,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}( f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t)) end - observedfun = let sys = sys, dict = Dict() - function generate_observed(obsvar, u, p, t) - obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar) - end - p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t) - end - end + observedfun = ObservedFunctionCache(sys) DiscreteFunction{iip, specialize}(f; sys = sys, diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 36b8719e1e..accc0bc39f 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -345,16 +345,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, f = DiffEqBase.DISCRETE_INPLACE_DEFAULT - # just taken from abstractodesystem.jl for ODEFunction def - obs = observed(sys) - observedfun = let sys = sys, dict = Dict() - function generated_observed(obsvar, u, p, t) - obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) - end - p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t) - end - end + observedfun = ObservedFunctionCache(sys) df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun) DiscreteProblem(df, u0, tspan, p; kwargs...) diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 4494e61074..f017494b13 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -337,27 +337,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, hess_prototype = nothing end - observedfun = let sys = sys, dict = Dict() - function generated_observed(obsvar, args...) - obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar) - end - if args === () - let obs = obs - _obs(u, p) = obs(u, p) - _obs(u, p::MTKParameters) = obs(u, p...) - _obs - end - else - u, p = args - if p isa MTKParameters - obs(u, p...) - else - obs(u, p) - end - end - end - end + observedfun = ObservedFunctionCache(sys) if length(cstr) > 0 @named cons_sys = ConstraintsSystem(cstr, dvs, ps)