Skip to content

Commit

Permalink
Merge pull request #615 from AayushSabharwal/as/integ-observed
Browse files Browse the repository at this point in the history
fix: bugs related to `observed`
  • Loading branch information
ChrisRackauckas authored Feb 10, 2024
2 parents 1a148f4 + 2a887de commit 3987f50
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 45 deletions.
6 changes: 0 additions & 6 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,6 @@ end

# SymbolicIndexingInterface
SymbolicIndexingInterface.symbolic_container(A::DEIntegrator) = A.f
function SymbolicIndexingInterface.is_observed(A::DEIntegrator, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()
end
function SymbolicIndexingInterface.observed(A::DEIntegrator, sym)
return getobserved(A)(sym)
end
SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p
SymbolicIndexingInterface.state_values(A::DEIntegrator) = A.u
SymbolicIndexingInterface.current_time(A::DEIntegrator) = A.t
Expand Down
9 changes: 2 additions & 7 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@ Base.@propagate_inbounds function Base.getproperty(prob::AbstractSciMLProblem, s
end

SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
function SymbolicIndexingInterface.is_observed(A::AbstractSciMLProblem, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()
end
function SymbolicIndexingInterface.observed(A::AbstractSciMLProblem, sym)
return getobserved(A)(sym)
end

SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p
SymbolicIndexingInterface.state_values(prob::AbstractSciMLProblem) = prob.u0
SymbolicIndexingInterface.current_time(prob::AbstractSciMLProblem) = prob.tspan[1]
Expand All @@ -33,7 +28,7 @@ Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
elseif is_independent_variable(prob.f, sym)
return getindepsym(prob)
elseif is_observed(prob.f, sym)
obs = SymbolicIndexingInterface.observed(prob.f, sym)
obs = SymbolicIndexingInterface.observed(prob, sym)
if is_time_dependent(prob.f)
return obs(prob.u0, prob.p, 0.0)
else
Expand Down
24 changes: 12 additions & 12 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2340,7 +2340,7 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys)
else
ODEFunction{iip, specialize,
Expand All @@ -2352,7 +2352,7 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys)
end
end
Expand Down Expand Up @@ -2864,7 +2864,7 @@ function unwrapped_f(f::SDEFunction, newf = unwrapped_f(f.f),
typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad),
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t),
typeof(f.paramjac), typeof(f.ggprime),
typeof(f.paramjac), typeof(f.ggprime),
typeof(f.observed), typeof(f.colorvec), typeof(f.sys)}(newf, newg,
f.mass_matrix,
f.analytic,
Expand Down Expand Up @@ -2959,7 +2959,7 @@ function SplitSDEFunction{iip, specialize}(f1, f2, g;
typeof(colorvec),
typeof(sys)}(f1, f2, g, mass_matrix, _func_cache, analytic,
tgrad, jac, jvp, vjp, jac_prototype, sparsity,
Wfact, Wfact_t, paramjac,
Wfact, Wfact_t, paramjac,
observed, colorvec, sys)
end
end
Expand Down Expand Up @@ -3136,7 +3136,7 @@ function RODEFunction{iip, specialize}(f;
typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
typeof(paramjac),
typeof(paramjac),
typeof(observed), typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad,
jac, jvp, vjp, jac_prototype, sparsity,
Expand Down Expand Up @@ -3489,7 +3489,7 @@ function SDDEFunction{iip, specialize}(f, g;
jvp, vjp, jac_prototype,
sparsity, Wfact,
Wfact_t,
paramjac, ggprime,
paramjac, ggprime,
observed, _colorvec, sys)
end
end
Expand Down Expand Up @@ -3582,7 +3582,7 @@ function NonlinearFunction{iip, specialize}(f;
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact),
typeof(Wfact_t), typeof(paramjac),
typeof(Wfact_t), typeof(paramjac),
typeof(observed),
typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(_f, mass_matrix,
analytic, tgrad, jac,
Expand Down Expand Up @@ -3951,7 +3951,7 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
function has_syms(f::AbstractSciMLFunction)
function has_syms(f::AbstractSciMLFunction)
if __has_syms(f)
f.syms !== nothing
else
Expand Down Expand Up @@ -4062,14 +4062,14 @@ end

SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction) = has_sys(fn) ? fn.sys : SymbolCache()

SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_observed(fn)
SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_sys(fn) ? is_observed(fn.sys, sym) : has_observed(fn)

function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym)
if has_observed(fn)
if is_time_dependent(fn)
return (u, p, t) -> fn.observed(sym, u, p, t)
if hasmethod(fn.observed, Tuple{Any})
return fn.observed(sym)
else
return (u, p) -> fn.observed(sym, u, p)
return (args...) -> fn.observed(sym, args...)
end
end
error("SciMLFunction does not have observed")
Expand Down
7 changes: 0 additions & 7 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,6 @@ SymbolicIndexingInterface.is_independent_variable(::AbstractNoTimeSolution, sym)

SymbolicIndexingInterface.independent_variable_symbols(::AbstractNoTimeSolution) = []

for soltype in [AbstractTimeseriesSolution, AbstractNoTimeSolution]
@eval function SymbolicIndexingInterface.observed(A::$(soltype), sym::Symbol)
has_sys(A.prob.f) || error("Cannot use observed without system")
return SymbolicIndexingInterface.observed(A, getproperty(A.prob.f.sys, sym))
end
end

SymbolicIndexingInterface.is_time_dependent(::AbstractTimeseriesSolution) = true

SymbolicIndexingInterface.is_time_dependent(::AbstractNoTimeSolution) = false
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
6 changes: 2 additions & 4 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,12 @@ getx = getu(sprob, x)
gety = getu(sprob, :y)
get_arr = getu(sprob, [x, y])
get_tuple = getu(sprob, (y, z))
# observed doesn't work the same for SDEs
# uncomment get_obs test below when fixed
@test_broken get_obs = getu(sprob, sys.x + sys.z + t + σ)
get_obs = getu(sprob, sys.x + sys.z + t + σ)
@test getx(sprob) == 10.0
@test gety(sprob) == 10.0
@test get_arr(sprob) == [10.0, 10.0]
@test get_tuple(sprob) == (10.0, 1.0)
# @test get_obs(sprob) == 39.0
@test get_obs(sprob) == 22.0

setx! = setu(sprob, x)
sety! = setu(sprob, :y)
Expand Down
29 changes: 20 additions & 9 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,32 @@ p = [lorenz1.σ => 10.0,

tspan = (0.0, 100.0)
prob = ODEProblem(sys_simplified, u0, tspan, p)
integ = init(prob, Rodas4())
sol = solve(prob, Rodas4())

@test_throws Any sol[b]
@test_throws Any sol[b, 1]
@test_throws Any sol[b, 1:5]
@test_throws Any sol[b, [1, 2, 3]]
@test_throws Any sol['a']
@test_throws Any sol['a', 1]
@test_throws Any sol['a', 1:5]
@test_throws Any sol['a', [1, 2, 3]]
@testset "indexing should error" begin
for obj in [prob, integ, sol]
for sym in ['a', :b]
@test_throws Any obj[sym]
@test_throws Any obj[sym, 1]
@test_throws Any obj[sym, 1:5]
@test_throws Any obj[sym, [1, 2, 3]]
end
end
end

@testset "observed shouldn't error" begin
for obj in [prob, integ, sol]
obj[:a]
SymbolicIndexingInterface.observed(obj, :a)
end
end


@test sol[a] isa AbstractVector
@test sol[:a] == sol[a]
@test sol[a, 1] isa Real
@test sol[:a, 1] == sol[a, 1]
@test sol[:a, 1] == sol[a, 1] == prob[a] == prob[:a] == integ[a] == integ[:a] == -1.0
@test sol[a, 1:5] isa AbstractVector
@test sol[:a, 1:5] == sol[a, 1:5]
@test sol[a, [1, 2, 3]] isa AbstractVector
Expand Down

0 comments on commit 3987f50

Please sign in to comment.