Skip to content

Commit

Permalink
possibly fix SymbolicIndexingInterface.observed(fn::AbstractSciMLFunc…
Browse files Browse the repository at this point in the history
…tion, sym)

This feels deeply wrong to me.
  • Loading branch information
oscardssmith committed Feb 5, 2024
1 parent 3912dff commit 1cb0ec2
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2324,7 +2324,7 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys)
else
ODEFunction{iip, specialize,
Expand All @@ -2336,7 +2336,7 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys)
end
end
Expand Down Expand Up @@ -2848,7 +2848,7 @@ function unwrapped_f(f::SDEFunction, newf = unwrapped_f(f.f),
typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad),
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t),
typeof(f.paramjac), typeof(f.ggprime),
typeof(f.paramjac), typeof(f.ggprime),
typeof(f.observed), typeof(f.colorvec), typeof(f.sys)}(newf, newg,
f.mass_matrix,
f.analytic,
Expand Down Expand Up @@ -2943,7 +2943,7 @@ function SplitSDEFunction{iip, specialize}(f1, f2, g;
typeof(colorvec),
typeof(sys)}(f1, f2, g, mass_matrix, _func_cache, analytic,
tgrad, jac, jvp, vjp, jac_prototype, sparsity,
Wfact, Wfact_t, paramjac,
Wfact, Wfact_t, paramjac,
observed, colorvec, sys)
end
end
Expand Down Expand Up @@ -3120,7 +3120,7 @@ function RODEFunction{iip, specialize}(f;
typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
typeof(paramjac),
typeof(paramjac),
typeof(observed), typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad,
jac, jvp, vjp, jac_prototype, sparsity,
Expand Down Expand Up @@ -3473,7 +3473,7 @@ function SDDEFunction{iip, specialize}(f, g;
jvp, vjp, jac_prototype,
sparsity, Wfact,
Wfact_t,
paramjac, ggprime,
paramjac, ggprime,
observed, _colorvec, sys)
end
end
Expand Down Expand Up @@ -3566,7 +3566,7 @@ function NonlinearFunction{iip, specialize}(f;
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact),
typeof(Wfact_t), typeof(paramjac),
typeof(Wfact_t), typeof(paramjac),
typeof(observed),
typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(_f, mass_matrix,
analytic, tgrad, jac,
Expand Down Expand Up @@ -3935,7 +3935,7 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
function has_syms(f::AbstractSciMLFunction)
function has_syms(f::AbstractSciMLFunction)

Check warning on line 3938 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L3938

Added line #L3938 was not covered by tests
if __has_syms(f)
f.syms !== nothing
else
Expand Down Expand Up @@ -4051,9 +4051,21 @@ SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_obse
function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym)
if has_observed(fn)
if is_time_dependent(fn)
return (u, p, t) -> fn.observed(sym, u, p, t)
return if hasmethod(fn.observed, Tuple{typeof(sym)})
fn.observed(sym)

Check warning on line 4055 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4054-L4055

Added lines #L4054 - L4055 were not covered by tests
else
let obs = fn.observed, sym = sym
(u, p, t) -> obs(sym, u, p, t)

Check warning on line 4058 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4057-L4058

Added lines #L4057 - L4058 were not covered by tests
end
end
else
return (u, p) -> fn.observed(sym, u, p)
return if hasmethod(fn.observed, Tuple{typeof(sym)})
fn.observed(sym)

Check warning on line 4063 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4062-L4063

Added lines #L4062 - L4063 were not covered by tests
else
let obs = fn.observed, sym = sym
(u, p) -> obs(sym, u, p)

Check warning on line 4066 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4065-L4066

Added lines #L4065 - L4066 were not covered by tests
end
end
end
end
error("SciMLFunction does not have observed")
Expand Down

0 comments on commit 1cb0ec2

Please sign in to comment.