From 4142b49d4501838f05ae057f8c89eae5a4dc810e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 10 Jan 2024 17:41:45 -0500 Subject: [PATCH 1/5] fix for DEintegrator --- src/integrator_interface.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index fb9f883ab..8d4196b7f 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -426,12 +426,6 @@ end # SymbolicIndexingInterface SymbolicIndexingInterface.symbolic_container(A::DEIntegrator) = A.f -function SymbolicIndexingInterface.is_observed(A::DEIntegrator, sym) - return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic() -end -function SymbolicIndexingInterface.observed(A::DEIntegrator, sym) - return getobserved(A)(sym) -end SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p SymbolicIndexingInterface.state_values(A::DEIntegrator) = A.u SymbolicIndexingInterface.current_time(A::DEIntegrator) = A.t From 07c81814349e50039027c82ae5a4efd185a5f5d8 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 10 Jan 2024 19:51:26 -0500 Subject: [PATCH 2/5] add tests --- test/downstream/Project.toml | 1 + test/downstream/symbol_indexing.jl | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 21d07c482..9464f940b 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -9,6 +9,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 76ce5afc4..9cb973373 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -39,21 +39,24 @@ p = [lorenz1.σ => 10.0, tspan = (0.0, 100.0) prob = ODEProblem(sys_simplified, u0, tspan, p) +integ = init(prob, Rodas4()) sol = solve(prob, Rodas4()) -@test_throws Any sol[b] -@test_throws Any sol[b, 1] -@test_throws Any sol[b, 1:5] -@test_throws Any sol[b, [1, 2, 3]] -@test_throws Any sol['a'] -@test_throws Any sol['a', 1] -@test_throws Any sol['a', 1:5] -@test_throws Any sol['a', [1, 2, 3]] +@testset "indexing should error" begin + for obj in [prob, integ, sol] + for sym in ['a', :b] + @test_throws Any obj[sym] + @test_throws Any obj[sym, 1] + @test_throws Any obj[sym, 1:5] + @test_throws Any obj[sym, [1, 2, 3]] + end + end +end @test sol[a] isa AbstractVector @test sol[:a] == sol[a] @test sol[a, 1] isa Real -@test sol[:a, 1] == sol[a, 1] +@test sol[:a, 1] == sol[a, 1] == prob[a] == prob[:a] == integ[a] == integ[:a] == -1.0 @test sol[a, 1:5] isa AbstractVector @test sol[:a, 1:5] == sol[a, 1:5] @test sol[a, [1, 2, 3]] isa AbstractVector From cf41e4b7d026041823109756660cda65f85869ae Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 11 Jan 2024 11:35:13 -0500 Subject: [PATCH 3/5] fix prob and add test --- src/problems/problem_interface.jl | 7 +------ test/downstream/symbol_indexing.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index 9aa60f7dc..81f186cac 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -6,12 +6,7 @@ Base.@propagate_inbounds function Base.getproperty(prob::AbstractSciMLProblem, s end SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f -function SymbolicIndexingInterface.is_observed(A::AbstractSciMLProblem, sym) - return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic() -end -function SymbolicIndexingInterface.observed(A::AbstractSciMLProblem, sym) - return getobserved(A)(sym) -end + SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p SymbolicIndexingInterface.state_values(prob::AbstractSciMLProblem) = prob.u0 SymbolicIndexingInterface.current_time(prob::AbstractSciMLProblem) = prob.tspan[1] diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 9cb973373..a937c68bf 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -53,6 +53,14 @@ sol = solve(prob, Rodas4()) end end +@testset "observed shouldn't error" begin + for obj in [prob, integ, sol] + obj[:a] + observed(obh, :a) + end +end + + @test sol[a] isa AbstractVector @test sol[:a] == sol[a] @test sol[a, 1] isa Real From 3912dffa432d27dfc74db7d1c6d6527d321c836e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 12 Jan 2024 15:04:25 -0500 Subject: [PATCH 4/5] fixes --- src/problems/problem_interface.jl | 2 +- src/solutions/solution_interface.jl | 7 ------- test/downstream/Project.toml | 1 + test/downstream/symbol_indexing.jl | 2 +- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index 81f186cac..a4213c906 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -28,7 +28,7 @@ Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym) elseif is_independent_variable(prob.f, sym) return getindepsym(prob) elseif is_observed(prob.f, sym) - obs = SymbolicIndexingInterface.observed(prob.f, sym) + obs = SymbolicIndexingInterface.observed(prob, sym) if is_time_dependent(prob.f) return obs(prob.u0, prob.p, 0.0) else diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f9b923efe..9d5dbaf50 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -47,13 +47,6 @@ SymbolicIndexingInterface.is_independent_variable(::AbstractNoTimeSolution, sym) SymbolicIndexingInterface.independent_variable_symbols(::AbstractNoTimeSolution) = [] -for soltype in [AbstractTimeseriesSolution, AbstractNoTimeSolution] - @eval function SymbolicIndexingInterface.observed(A::$(soltype), sym::Symbol) - has_sys(A.prob.f) || error("Cannot use observed without system") - return SymbolicIndexingInterface.observed(A, getproperty(A.prob.f.sys, sym)) - end -end - SymbolicIndexingInterface.is_time_dependent(::AbstractTimeseriesSolution) = true SymbolicIndexingInterface.is_time_dependent(::AbstractNoTimeSolution) = false diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 9464f940b..4cd495893 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -6,6 +6,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index a937c68bf..810f3801d 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -56,7 +56,7 @@ end @testset "observed shouldn't error" begin for obj in [prob, integ, sol] obj[:a] - observed(obh, :a) + SymbolicIndexingInterface.observed(obj, :a) end end From 1cb0ec200b3fa3d01e8eef551700bb9154d96c9b Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 5 Feb 2024 09:41:06 -0500 Subject: [PATCH 5/5] possibly fix SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) This feels deeply wrong to me. --- src/scimlfunctions.jl | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 1027aaea3..d3464cdcd 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2324,7 +2324,7 @@ function ODEFunction{iip, specialize}(f; typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, W_prototype, paramjac, + Wfact_t, W_prototype, paramjac, observed, _colorvec, sys) else ODEFunction{iip, specialize, @@ -2336,7 +2336,7 @@ function ODEFunction{iip, specialize}(f; typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, W_prototype, paramjac, + Wfact_t, W_prototype, paramjac, observed, _colorvec, sys) end end @@ -2848,7 +2848,7 @@ function unwrapped_f(f::SDEFunction, newf = unwrapped_f(f.f), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype), typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), - typeof(f.paramjac), typeof(f.ggprime), + typeof(f.paramjac), typeof(f.ggprime), typeof(f.observed), typeof(f.colorvec), typeof(f.sys)}(newf, newg, f.mass_matrix, f.analytic, @@ -2943,7 +2943,7 @@ function SplitSDEFunction{iip, specialize}(f1, f2, g; typeof(colorvec), typeof(sys)}(f1, f2, g, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, - Wfact, Wfact_t, paramjac, + Wfact, Wfact_t, paramjac, observed, colorvec, sys) end end @@ -3120,7 +3120,7 @@ function RODEFunction{iip, specialize}(f; typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), - typeof(paramjac), + typeof(paramjac), typeof(observed), typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, @@ -3473,7 +3473,7 @@ function SDDEFunction{iip, specialize}(f, g; jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, - paramjac, ggprime, + paramjac, ggprime, observed, _colorvec, sys) end end @@ -3566,7 +3566,7 @@ function NonlinearFunction{iip, specialize}(f; typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), - typeof(Wfact_t), typeof(paramjac), + typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(_f, mass_matrix, analytic, tgrad, jac, @@ -3935,7 +3935,7 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing -function has_syms(f::AbstractSciMLFunction) +function has_syms(f::AbstractSciMLFunction) if __has_syms(f) f.syms !== nothing else @@ -4051,9 +4051,21 @@ SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) = has_obse function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) if has_observed(fn) if is_time_dependent(fn) - return (u, p, t) -> fn.observed(sym, u, p, t) + return if hasmethod(fn.observed, Tuple{typeof(sym)}) + fn.observed(sym) + else + let obs = fn.observed, sym = sym + (u, p, t) -> obs(sym, u, p, t) + end + end else - return (u, p) -> fn.observed(sym, u, p) + return if hasmethod(fn.observed, Tuple{typeof(sym)}) + fn.observed(sym) + else + let obs = fn.observed, sym = sym + (u, p) -> obs(sym, u, p) + end + end end end error("SciMLFunction does not have observed")