From 4142b49d4501838f05ae057f8c89eae5a4dc810e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 10 Jan 2024 17:41:45 -0500 Subject: [PATCH 1/7] fix for DEintegrator --- src/integrator_interface.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index fb9f883ab..8d4196b7f 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -426,12 +426,6 @@ end # SymbolicIndexingInterface SymbolicIndexingInterface.symbolic_container(A::DEIntegrator) = A.f -function SymbolicIndexingInterface.is_observed(A::DEIntegrator, sym) - return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic() -end -function SymbolicIndexingInterface.observed(A::DEIntegrator, sym) - return getobserved(A)(sym) -end SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p SymbolicIndexingInterface.state_values(A::DEIntegrator) = A.u SymbolicIndexingInterface.current_time(A::DEIntegrator) = A.t From 07c81814349e50039027c82ae5a4efd185a5f5d8 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 10 Jan 2024 19:51:26 -0500 Subject: [PATCH 2/7] add tests --- test/downstream/Project.toml | 1 + test/downstream/symbol_indexing.jl | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 21d07c482..9464f940b 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -9,6 +9,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 76ce5afc4..9cb973373 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -39,21 +39,24 @@ p = [lorenz1.σ => 10.0, tspan = (0.0, 100.0) prob = ODEProblem(sys_simplified, u0, tspan, p) +integ = init(prob, Rodas4()) sol = solve(prob, Rodas4()) -@test_throws Any sol[b] -@test_throws Any sol[b, 1] -@test_throws Any sol[b, 1:5] -@test_throws Any sol[b, [1, 2, 3]] -@test_throws Any sol['a'] -@test_throws Any sol['a', 1] -@test_throws Any sol['a', 1:5] -@test_throws Any sol['a', [1, 2, 3]] +@testset "indexing should error" begin + for obj in [prob, integ, sol] + for sym in ['a', :b] + @test_throws Any obj[sym] + @test_throws Any obj[sym, 1] + @test_throws Any obj[sym, 1:5] + @test_throws Any obj[sym, [1, 2, 3]] + end + end +end @test sol[a] isa AbstractVector @test sol[:a] == sol[a] @test sol[a, 1] isa Real -@test sol[:a, 1] == sol[a, 1] +@test sol[:a, 1] == sol[a, 1] == prob[a] == prob[:a] == integ[a] == integ[:a] == -1.0 @test sol[a, 1:5] isa AbstractVector @test sol[:a, 1:5] == sol[a, 1:5] @test sol[a, [1, 2, 3]] isa AbstractVector From cf41e4b7d026041823109756660cda65f85869ae Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 11 Jan 2024 11:35:13 -0500 Subject: [PATCH 3/7] fix prob and add test --- src/problems/problem_interface.jl | 7 +------ test/downstream/symbol_indexing.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index 9aa60f7dc..81f186cac 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -6,12 +6,7 @@ Base.@propagate_inbounds function Base.getproperty(prob::AbstractSciMLProblem, s end SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f -function SymbolicIndexingInterface.is_observed(A::AbstractSciMLProblem, sym) - return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic() -end -function SymbolicIndexingInterface.observed(A::AbstractSciMLProblem, sym) - return getobserved(A)(sym) -end + SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p SymbolicIndexingInterface.state_values(prob::AbstractSciMLProblem) = prob.u0 SymbolicIndexingInterface.current_time(prob::AbstractSciMLProblem) = prob.tspan[1] diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 9cb973373..a937c68bf 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -53,6 +53,14 @@ sol = solve(prob, Rodas4()) end end +@testset "observed shouldn't error" begin + for obj in [prob, integ, sol] + obj[:a] + observed(obh, :a) + end +end + + @test sol[a] isa AbstractVector @test sol[:a] == sol[a] @test sol[a, 1] isa Real From 3912dffa432d27dfc74db7d1c6d6527d321c836e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 12 Jan 2024 15:04:25 -0500 Subject: [PATCH 4/7] fixes --- src/problems/problem_interface.jl | 2 +- src/solutions/solution_interface.jl | 7 ------- test/downstream/Project.toml | 1 + test/downstream/symbol_indexing.jl | 2 +- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index 81f186cac..a4213c906 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -28,7 +28,7 @@ Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym) elseif is_independent_variable(prob.f, sym) return getindepsym(prob) elseif is_observed(prob.f, sym) - obs = SymbolicIndexingInterface.observed(prob.f, sym) + obs = SymbolicIndexingInterface.observed(prob, sym) if is_time_dependent(prob.f) return obs(prob.u0, prob.p, 0.0) else diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f9b923efe..9d5dbaf50 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -47,13 +47,6 @@ SymbolicIndexingInterface.is_independent_variable(::AbstractNoTimeSolution, sym) SymbolicIndexingInterface.independent_variable_symbols(::AbstractNoTimeSolution) = [] -for soltype in [AbstractTimeseriesSolution, AbstractNoTimeSolution] - @eval function SymbolicIndexingInterface.observed(A::$(soltype), sym::Symbol) - has_sys(A.prob.f) || error("Cannot use observed without system") - return SymbolicIndexingInterface.observed(A, getproperty(A.prob.f.sys, sym)) - end -end - SymbolicIndexingInterface.is_time_dependent(::AbstractTimeseriesSolution) = true SymbolicIndexingInterface.is_time_dependent(::AbstractNoTimeSolution) = false diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 9464f940b..4cd495893 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -6,6 +6,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index a937c68bf..810f3801d 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -56,7 +56,7 @@ end @testset "observed shouldn't error" begin for obj in [prob, integ, sol] obj[:a] - observed(obh, :a) + SymbolicIndexingInterface.observed(obj, :a) end end From 1cb0ec200b3fa3d01e8eef551700bb9154d96c9b Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 5 Feb 2024 09:41:06 -0500 Subject: [PATCH 5/7] possibly fix SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) This feels deeply wrong to me. --- src/scimlfunctions.jl | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 1027aaea3..d3464cdcd 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2324,7 +2324,7 @@ function ODEFunction{iip, specialize}(f; typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, W_prototype, paramjac, + Wfact_t, W_prototype, paramjac, observed, _colorvec, sys) else ODEFunction{iip, specialize, @@ -2336,7 +2336,7 @@ function ODEFunction{iip, specialize}(f; typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, W_prototype, paramjac, + Wfact_t, W_prototype, paramjac, observed, _colorvec, sys) end end @@ -2848,7 +2848,7 @@ function unwrapped_f(f::SDEFunction, newf = unwrapped_f(f.f), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype), typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), - typeof(f.paramjac), typeof(f.ggprime), + typeof(f.paramjac), typeof(f.ggprime), typeof(f.observed), typeof(f.colorvec), typeof(f.sys)}(newf, newg, f.mass_matrix, f.analytic, @@ -2943,7 +2943,7 @@ function SplitSDEFunction{iip, specialize}(f1, f2, g; typeof(colorvec), typeof(sys)}(f1, f2, g, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, - Wfact, Wfact_t, paramjac, + Wfact, Wfact_t, paramjac, observed, colorvec, sys) end end @@ -3120,7 +3120,7 @@ function RODEFunction{iip, specialize}(f; typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), - typeof(paramjac), + typeof(paramjac), typeof(observed), typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, @@ -3473,7 +3473,7 @@ function SDDEFunction{iip, specialize}(f, g; jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, - paramjac, ggprime, + paramjac, ggprime, observed, _colorvec, sys) end end @@ -3566,7 +3566,7 @@ function NonlinearFunction{iip, specialize}(f; typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), - typeof(Wfact_t), typeof(paramjac), + typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(_f, mass_matrix, analytic, tgrad, jac, @@ -3935,7 +3935,7 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing -function has_syms(f::AbstractSciMLFunction) +function has_syms(f::AbstractSciMLFunction) if __has_syms(f) f.syms !== nothing else @@ -4051,9 +4051,21 @@ SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_obse function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) if has_observed(fn) if is_time_dependent(fn) - return (u, p, t) -> fn.observed(sym, u, p, t) + return if hasmethod(fn.observed, Tuple{typeof(sym)}) + fn.observed(sym) + else + let obs = fn.observed, sym = sym + (u, p, t) -> obs(sym, u, p, t) + end + end else - return (u, p) -> fn.observed(sym, u, p) + return if hasmethod(fn.observed, Tuple{typeof(sym)}) + fn.observed(sym) + else + let obs = fn.observed, sym = sym + (u, p) -> obs(sym, u, p) + end + end end end error("SciMLFunction does not have observed") From d8823c9b573d952da7062a8e233fb8e5c5ac0a0d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 6 Feb 2024 16:55:51 +0530 Subject: [PATCH 6/7] fix: bugs related to `observed` --- src/scimlfunctions.jl | 21 +++++---------------- test/downstream/problem_interface.jl | 6 ++---- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index d3464cdcd..aca1b304c 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4046,26 +4046,15 @@ end SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction) = has_sys(fn) ? fn.sys : SymbolCache() -SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_observed(fn) +SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_sys(fn) ? is_observed(fn.sys, sym) : has_observed(fn) function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) if has_observed(fn) - if is_time_dependent(fn) - return if hasmethod(fn.observed, Tuple{typeof(sym)}) - fn.observed(sym) - else - let obs = fn.observed, sym = sym - (u, p, t) -> obs(sym, u, p, t) - end - end + if hasmethod(fn.observed, Tuple{Any}) + @show sym + return fn.observed(sym) else - return if hasmethod(fn.observed, Tuple{typeof(sym)}) - fn.observed(sym) - else - let obs = fn.observed, sym = sym - (u, p) -> obs(sym, u, p) - end - end + return (args...) -> fn.observed(sym, args...) end end error("SciMLFunction does not have observed") diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 15f587d64..90308ee26 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -140,14 +140,12 @@ getx = getu(sprob, x) gety = getu(sprob, :y) get_arr = getu(sprob, [x, y]) get_tuple = getu(sprob, (y, z)) -# observed doesn't work the same for SDEs -# uncomment get_obs test below when fixed -@test_broken get_obs = getu(sprob, sys.x + sys.z + t + σ) +get_obs = getu(sprob, sys.x + sys.z + t + σ) @test getx(sprob) == 10.0 @test gety(sprob) == 10.0 @test get_arr(sprob) == [10.0, 10.0] @test get_tuple(sprob) == (10.0, 1.0) -# @test get_obs(sprob) == 39.0 +@test get_obs(sprob) == 22.0 setx! = setu(sprob, x) sety! = setu(sprob, :y) From 2a887de8aa126eb978d51ecf3993c9189583fab3 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 8 Feb 2024 10:48:04 +0100 Subject: [PATCH 7/7] Update src/scimlfunctions.jl --- src/scimlfunctions.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index aca1b304c..f5ad463b8 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4051,7 +4051,6 @@ SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_sys( function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) if has_observed(fn) if hasmethod(fn.observed, Tuple{Any}) - @show sym return fn.observed(sym) else return (args...) -> fn.observed(sym, args...)