From 1cb0ec200b3fa3d01e8eef551700bb9154d96c9b Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 5 Feb 2024 09:41:06 -0500 Subject: [PATCH] possibly fix SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) This feels deeply wrong to me. --- src/scimlfunctions.jl | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 1027aaea3..d3464cdcd 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2324,7 +2324,7 @@ function ODEFunction{iip, specialize}(f; typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, W_prototype, paramjac, + Wfact_t, W_prototype, paramjac, observed, _colorvec, sys) else ODEFunction{iip, specialize, @@ -2336,7 +2336,7 @@ function ODEFunction{iip, specialize}(f; typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, W_prototype, paramjac, + Wfact_t, W_prototype, paramjac, observed, _colorvec, sys) end end @@ -2848,7 +2848,7 @@ function unwrapped_f(f::SDEFunction, newf = unwrapped_f(f.f), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype), typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), - typeof(f.paramjac), typeof(f.ggprime), + typeof(f.paramjac), typeof(f.ggprime), typeof(f.observed), typeof(f.colorvec), typeof(f.sys)}(newf, newg, f.mass_matrix, f.analytic, @@ -2943,7 +2943,7 @@ function SplitSDEFunction{iip, specialize}(f1, f2, g; typeof(colorvec), typeof(sys)}(f1, f2, g, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, - Wfact, Wfact_t, paramjac, + Wfact, Wfact_t, paramjac, observed, colorvec, sys) end end @@ -3120,7 +3120,7 @@ function RODEFunction{iip, specialize}(f; typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), - typeof(paramjac), + typeof(paramjac), typeof(observed), typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, @@ -3473,7 +3473,7 @@ function SDDEFunction{iip, specialize}(f, g; jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, - paramjac, ggprime, + paramjac, ggprime, observed, _colorvec, sys) end end @@ -3566,7 +3566,7 @@ function NonlinearFunction{iip, specialize}(f; typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), - typeof(Wfact_t), typeof(paramjac), + typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(_f, mass_matrix, analytic, tgrad, jac, @@ -3935,7 +3935,7 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing -function has_syms(f::AbstractSciMLFunction) +function has_syms(f::AbstractSciMLFunction) if __has_syms(f) f.syms !== nothing else @@ -4051,9 +4051,21 @@ SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_obse function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) if has_observed(fn) if is_time_dependent(fn) - return (u, p, t) -> fn.observed(sym, u, p, t) + return if hasmethod(fn.observed, Tuple{typeof(sym)}) + fn.observed(sym) + else + let obs = fn.observed, sym = sym + (u, p, t) -> obs(sym, u, p, t) + end + end else - return (u, p) -> fn.observed(sym, u, p) + return if hasmethod(fn.observed, Tuple{typeof(sym)}) + fn.observed(sym) + else + let obs = fn.observed, sym = sym + (u, p) -> obs(sym, u, p) + end + end end end error("SciMLFunction does not have observed")