From 738f9883784a1fafc866cd3712bc974f52ca1d7b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 3 Jan 2024 13:47:37 +0530 Subject: [PATCH 1/6] feat: add support for ParameterIndexingProxy --- Project.toml | 2 +- src/integrator_interface.jl | 2 + src/problems/basic_problems.jl | 2 + src/problems/problem_interface.jl | 7 ++++ src/solutions/dae_solutions.jl | 2 + src/solutions/ode_solutions.jl | 2 + src/solutions/optimization_solutions.jl | 2 + src/solutions/rode_solutions.jl | 2 + src/solutions/solution_interface.jl | 9 +++++ test/downstream/integrator_indexing.jl | 14 +++++++ test/downstream/problem_interface.jl | 54 +++++++++++++++++++------ test/downstream/symbol_indexing.jl | 11 +++-- 12 files changed, 92 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index b2f918c20..59a7f7d9e 100644 --- a/Project.toml +++ b/Project.toml @@ -81,7 +81,7 @@ SciMLOperators = "0.3.7" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.9" -SymbolicIndexingInterface = "0.3.2" +SymbolicIndexingInterface = "0.3.3" Tables = "1.11" TruncatedStacktraces = "1.4" Zygote = "0.6.67" diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index d1e3aeb9e..fb9f883ab 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -451,6 +451,8 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol) if sym === :destats && hasfield(typeof(A), :stats) @warn "destats has been deprecated for stats" getfield(A, :stats) + elseif sym === :ps + return ParameterIndexingProxy(A) else return getfield(A, sym) end diff --git a/src/problems/basic_problems.jl b/src/problems/basic_problems.jl index a37ba0c54..d15f6af52 100644 --- a/src/problems/basic_problems.jl +++ b/src/problems/basic_problems.jl @@ -502,6 +502,8 @@ function Base.getproperty(prob::IntegralProblem, name::Symbol) domain = getfield(prob, :domain) lb, ub = domain return ub + elseif name === :ps + return ParameterIndexingProxy(prob) end return Base.getfield(prob, name) end diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index d9e74916d..ec3bcd989 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -1,3 +1,10 @@ +Base.@propagate_inbounds function Base.getproperty(prob::AbstractSciMLProblem, sym::Symbol) + if sym === :ps + return ParameterIndexingProxy(prob) + end + return getfield(prob, sym) +end + SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f function SymbolicIndexingInterface.is_observed(A::AbstractSciMLProblem, sym) return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic() diff --git a/src/solutions/dae_solutions.jl b/src/solutions/dae_solutions.jl index 9852c9349..8cbbc839c 100644 --- a/src/solutions/dae_solutions.jl +++ b/src/solutions/dae_solutions.jl @@ -47,6 +47,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractDAESolution, s::Sy if s === :destats Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats") return getfield(x, :stats) + elseif s === :ps + return ParameterIndexingProxy(x) end return getfield(x, s) end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index c95d2a0cb..270377b63 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -121,6 +121,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy if s === :destats Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats") return getfield(x, :stats) + elseif s === :ps + return ParameterIndexingProxy(x) end return getfield(x, s) end diff --git a/src/solutions/optimization_solutions.jl b/src/solutions/optimization_solutions.jl index 3331ed513..3098820bd 100644 --- a/src/solutions/optimization_solutions.jl +++ b/src/solutions/optimization_solutions.jl @@ -115,6 +115,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractOptimizationSoluti Base.depwarn("`sol.prob` is deprecated. Use getters like `get_p` or `get_syms` on `sol` instead.", "sol.prob") return getfield(x, :cache) + elseif s === :ps + return ParameterIndexingProxy(x) end return getfield(x, s) end diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index e5cb9efe8..bf34520c4 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -55,6 +55,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractRODESolution, s::S if s === :destats Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats") return getfield(x, :stats) + elseif s === :ps + return ParameterIndexingProxy(x) end return getfield(x, s) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index dbd59e557..f1c2b948f 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -29,6 +29,15 @@ end # SymbolicIndexingInterface.jl const AbstractSolution = Union{AbstractTimeseriesSolution,AbstractNoTimeSolution} + +Base.@propagate_inbounds function Base.getproperty(A::AbstractSolution, sym::Symbol) + if sym === :ps + return ParameterIndexingProxy(A) + else + return getfield(A, sym) + end +end + SymbolicIndexingInterface.symbolic_container(A::AbstractSolution) = A.prob.f SymbolicIndexingInterface.parameter_values(A::AbstractSolution) = A.prob.p diff --git a/test/downstream/integrator_indexing.jl b/test/downstream/integrator_indexing.jl index cef50c62c..b5f74cba1 100644 --- a/test/downstream/integrator_indexing.jl +++ b/test/downstream/integrator_indexing.jl @@ -22,9 +22,13 @@ integrator = init(oprob, Rodas4()) @test_throws Exception integrator[population_model.a] @test_throws Exception integrator[:a] @test getp(oprob, a)(integrator) == getp(oprob, population_model.a)(integrator) == getp(oprob, :a)(integrator) == 2.0 +@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 2.0 @test getp(oprob, b)(integrator) == getp(oprob, population_model.b)(integrator) == getp(oprob, :b)(integrator) == 1.0 +@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 1.0 @test getp(oprob, c)(integrator) == getp(oprob, population_model.c)(integrator) == getp(oprob, :c)(integrator) == 1.0 +@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 1.0 @test getp(oprob, d)(integrator) == getp(oprob, population_model.d)(integrator) == getp(oprob, :d)(integrator) == 1.0 +@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 1.0 @test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0 @test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0 @@ -42,10 +46,15 @@ step!(integrator, 100.0, true) setp(oprob, a)(integrator, 10.0) @test getp(integrator, a)(integrator) == getp(integrator, population_model.a)(integrator) == getp(integrator, :a)(integrator) == 10.0 +@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 10.0 setp(population_model, population_model.b)(integrator, 20.0) @test getp(integrator, b)(integrator) == getp(integrator, population_model.b)(integrator) == getp(integrator, :b)(integrator) == 20.0 +@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 20.0 setp(integrator, c)(integrator, 30.0) @test getp(integrator, c)(integrator) == getp(integrator, population_model.c)(integrator) == getp(integrator, :c)(integrator) == 30.0 +@test integrator.ps[c] == integrator.ps[population_model.c] == integrator.ps[:c] == 30.0 +integrator.ps[d] = 40.0 +@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 40.0 integrator[s1] = 10.0 @test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 10.0 @@ -330,3 +339,8 @@ set_arr!(integrator, [[1.0, 2.0, 3.0], 1.0]) @test get_arr(integrator) == [[1.0, 2.0, 3.0], 1.0] set_tuple!(integrator, ([2.0, 4.0, 6.0], 2.0)) @test get_tuple(integrator) == ([2.0, 4.0, 6.0], 2.0) +@test getp(sys, p)(integrator) == integrator.ps[p] == [1, 2, 3] +setp(sys, p)(integrator, [4, 5, 6]) +@test getp(sys, p)(integrator) == integrator.ps[p] == [4, 5, 6] +integrator.ps[p] = [7, 8, 9] +@test getp(sys, p)(integrator) == integrator.ps[p] == [7, 8, 9] diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 1b667d991..6757cf755 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -33,15 +33,15 @@ oprob = ODEProblem(sys, u0, tspan, p, jac = true) getσ1 = getp(sys, σ) getσ2 = getp(sys, sys.σ) getσ3 = getp(sys, :σ) -@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 28.0 +@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[:σ] == 28.0 getρ1 = getp(sys, ρ) getρ2 = getp(sys, sys.ρ) getρ3 = getp(sys, :ρ) -@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 10.0 +@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[:ρ] == 10.0 getβ1 = getp(sys, β) getβ2 = getp(sys, sys.β) getβ3 = getp(sys, :β) -@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 8 / 3 +@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[:β] == 8 / 3 @test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0 @test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0 @@ -51,13 +51,20 @@ getβ3 = getp(sys, :β) setσ = setp(sys, σ) setσ(oprob, 10.0) -@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 10.0 +@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[:σ] == 10.0 setρ = setp(sys, sys.ρ) setρ(oprob, 20.0) -@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 20.0 +@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[:ρ] == 20.0 setβ = setp(sys, :β) setβ(oprob, 30.0) -@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 30.0 +@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[:β] == 30.0 + +oprob.ps[σ] = 11.0 +@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[:σ] == 11.0 +oprob.ps[sys.ρ] = 21.0 +@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[:ρ] == 21.0 +oprob.ps[:β] = 31.0 +@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[:β] == 31.0 oprob[x] = 10.0 @test oprob[x] == oprob[sys.x] == oprob[:x] == 10.0 @@ -103,22 +110,28 @@ noiseeqs = [0.1 * x, sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p) u0 -@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 28.0 -@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 10.0 -@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 8 / 3 +@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[:σ] == 28.0 +@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[:ρ] == 10.0 +@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[:β] == 8 / 3 @test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 1.0 @test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 0.0 @test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 0.0 setσ(sprob, 10.0) -@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 10.0 +@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[:σ] == 10.0 setρ(sprob, 20.0) -@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 20.0 +@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[:ρ] == 20.0 setp(noise_sys, noise_sys.ρ)(sprob, 25.0) -@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 25.0 +@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[:ρ] == 25.0 setβ(sprob, 30.0) -@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 30.0 +@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[:β] == 30.0 +sprob.ps[σ] = 11.0 +@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[:σ] == 11.0 +sprob.ps[sys.ρ] = 21.0 +@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[:ρ] == 21.0 +sprob.ps[:β] = 31.0 +@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[:β] == 31.0 sprob[x] = 10.0 @test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 10.0 @@ -155,3 +168,18 @@ set_tuple!(sprob, 13.0) @test get_tuple(sprob) == (13.0, 13.0) set_tuple!(sprob, [10.0, 10.0]) @test get_tuple(sprob) == (10.0, 10.0) + +using LinearAlgebra +@variables t +sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0 +ps = @parameters p[1:3] = [1, 2, 3] +D = Differential(t) +eqs = [collect(D.(x) .~ x) + D(y) ~ norm(x) * y - x[1]] +@named sys = ODESystem(eqs, t, [sts...;], [ps...;]) +prob = ODEProblem(sys, [], (0, 1.0)) +@test getp(sys, p)(prob) == prob.ps[p] == [1, 2, 3] +setp(sys, p)(prob, [4, 5, 6]) +@test getp(sys, p)(prob) == prob.ps[p] == [4, 5, 6] +prob.ps[p] = [7, 8, 9] +@test getp(sys, p)(prob) == prob.ps[p] == [7, 8, 9] diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 55d9ffbe6..397f5d9b6 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -72,8 +72,10 @@ sol = solve(prob, Rodas4()) @test sol[α, 3] isa Float64 @test length(sol[α, 5:10]) == 6 @test getp(prob, γ)(sol) isa Real -@test getp(prob, γ)(sol) == getp(prob, :γ)(sol) == 2.0 +@test sol.ps[γ] isa Real +@test getp(prob, γ)(sol) == getp(prob, :γ)(sol) == sol.ps[γ] == sol.ps[:γ] == 2.0 @test getp(prob, (lorenz1.σ, lorenz1.ρ))(sol) isa Tuple +@test sol.ps[(lorenz1.σ, lorenz1.ρ)] isa Tuple @test sol[[lorenz1.x, lorenz2.x]] isa Vector{Vector{Float64}} @test length(sol[[lorenz1.x, lorenz2.x]]) == length(sol) @@ -179,8 +181,8 @@ sol10 = sol(0.1, idxs = 2) getx = getu(sys_simplified, lorenz1.x) get_arr = getu(sys_simplified, [lorenz1.x, lorenz2.x]) get_tuple = getu(sys_simplified, (lorenz1.x, lorenz2.x)) -get_obs = getu(sys_simplified, lorenz1.x + lorenz2.x) -get_obs_arr = getu(sys_simplified, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y]) +get_obs = getu(sol, lorenz1.x + lorenz2.x) # can't use sys for observed +get_obs_arr = getu(sol, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y]) l1x_idx = variable_index(sol, lorenz1.x) l2x_idx = variable_index(sol, lorenz2.x) l1y_idx = variable_index(sol, lorenz1.y) @@ -213,6 +215,7 @@ prob = ODEProblem(sys, [], (0, 1.0)) sol = solve(prob, Tsit5()) @test sol[x] isa Vector{<:Vector} @test sol[@nonamespace sys.x] isa Vector{<:Vector} +@test sol.ps[p] == [1, 2, 3] getx = getu(sys, x) get_mix_arr = getu(sys, [x, y]) @@ -251,4 +254,6 @@ sol = solve(prob, Tsit5()) @test sol[y]≈1 atol=1e-3 @test getp(sys, a)(sol) ≈ 1 @test getp(sys, b)(sol) ≈ 100 + @test sol.ps[a] ≈ 1 + @test sol.ps[b] ≈ 100 end From c73c07390a5f9150a29b54dbd69e56c13ea26d52 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 9 Jan 2024 17:48:08 +0530 Subject: [PATCH 2/6] fix: fix bug in problem indexing --- src/problems/problem_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index ec3bcd989..9aa60f7dc 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -43,7 +43,7 @@ Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym) error("Invalid indexing of problem: $sym is not a state, parameter, or independent variable") end elseif symbolic_type(sym) == ArraySymbolic() - return map(s -> prob[s], sym) + return map(s -> prob[s], collect(sym)) else sym isa AbstractArray || error("Invalid indexing of problem") return map(s -> prob[s], sym) From 9c562863e72f7ba5e101acecbfb02e3e22f0e22a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 9 Jan 2024 17:48:29 +0530 Subject: [PATCH 3/6] test: fix integrator indexing tests --- test/downstream/integrator_indexing.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/downstream/integrator_indexing.jl b/test/downstream/integrator_indexing.jl index b5f74cba1..1dff154d9 100644 --- a/test/downstream/integrator_indexing.jl +++ b/test/downstream/integrator_indexing.jl @@ -309,7 +309,7 @@ eqs = [collect(D.(x) .~ x) D(y) ~ norm(x) * y - x[1]] @named sys = ODESystem(eqs, t, [sts...;], [ps...;]) prob = ODEProblem(sys, [], (0, 1.0)) -integrator = init(prob, Tsit5()) +integrator = init(prob, Tsit5(), save_everystep = false) @test integrator[x] isa Vector{Float64} @test integrator[@nonamespace sys.x] isa Vector{Float64} @@ -333,8 +333,6 @@ setx!(integrator, [4.0, 5.0, 6.0]) @test getx(integrator) == [4.0, 5.0, 6.0] sety!(integrator, 3.0) @test gety(integrator) == 3.0 -set_arr!(integrator, [1.0, 2.0]) -@test get_arr(integrator) == [[1.0, 1.0, 1.0], 2.0] set_arr!(integrator, [[1.0, 2.0, 3.0], 1.0]) @test get_arr(integrator) == [[1.0, 2.0, 3.0], 1.0] set_tuple!(integrator, ([2.0, 4.0, 6.0], 2.0)) From 00f3e4e0c6a2342b6fb1d0b041790ae596f93c4d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 9 Jan 2024 17:48:39 +0530 Subject: [PATCH 4/6] test: fix problem indexing tests --- test/downstream/problem_interface.jl | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 6757cf755..15f587d64 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -82,7 +82,7 @@ get_obs = getu(oprob, sys.x + sys.z + t + σ) @test gety(oprob) == 10.0 @test get_arr(oprob) == [10.0, 10.0] @test get_tuple(oprob) == (10.0, 1.0) -@test get_obs(oprob) == 39.0 +@test get_obs(oprob) == 22.0 setx! = setu(oprob, x) sety! = setu(oprob, :y) @@ -93,12 +93,8 @@ setx!(oprob, 11.0) @test getx(oprob) == 11.0 sety!(oprob, 12.0) @test gety(oprob) == 12.0 -set_arr!(oprob, 10.0) -@test get_arr(oprob) == [10.0, 10.0] set_arr!(oprob, [11.0, 12.0]) @test get_arr(oprob) == [11.0, 12.0] -set_tuple!(oprob, 13.0) -@test get_tuple(oprob) == (13.0, 13.0) set_tuple!(oprob, [10.0, 10.0]) @test get_tuple(oprob) == (10.0, 10.0) @@ -144,12 +140,14 @@ getx = getu(sprob, x) gety = getu(sprob, :y) get_arr = getu(sprob, [x, y]) get_tuple = getu(sprob, (y, z)) -get_obs = getu(sprob, sys.x + sys.z + t + σ) +# observed doesn't work the same for SDEs +# uncomment get_obs test below when fixed +@test_broken get_obs = getu(sprob, sys.x + sys.z + t + σ) @test getx(sprob) == 10.0 @test gety(sprob) == 10.0 @test get_arr(sprob) == [10.0, 10.0] @test get_tuple(sprob) == (10.0, 1.0) -@test get_obs(sprob) == 39.0 +# @test get_obs(sprob) == 39.0 setx! = setu(sprob, x) sety! = setu(sprob, :y) @@ -160,12 +158,8 @@ setx!(sprob, 11.0) @test getx(sprob) == 11.0 sety!(sprob, 12.0) @test gety(sprob) == 12.0 -set_arr!(sprob, 10.0) -@test get_arr(sprob) == [10.0, 10.0] set_arr!(sprob, [11.0, 12.0]) @test get_arr(sprob) == [11.0, 12.0] -set_tuple!(sprob, 13.0) -@test get_tuple(sprob) == (13.0, 13.0) set_tuple!(sprob, [10.0, 10.0]) @test get_tuple(sprob) == (10.0, 10.0) From 0401998e9046b00fca276466b0fbef542ef5b7e8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 9 Jan 2024 17:49:06 +0530 Subject: [PATCH 5/6] test: extensively test getu/setu and their type-stability --- test/downstream/symbol_indexing.jl | 115 ++++++++++++++++++++++++++--- 1 file changed, 104 insertions(+), 11 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 397f5d9b6..fc0fc1675 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -188,11 +188,11 @@ l2x_idx = variable_index(sol, lorenz2.x) l1y_idx = variable_index(sol, lorenz1.y) l2y_idx = variable_index(sol, lorenz2.y) -@test getx(sol) == sol[:, l1x_idx] -@test get_arr(sol) == sol[:, [l1x_idx, l2x_idx]] -@test get_tuple(sol) == tuple.(sol[:, l1x_idx], sol[:, l2x_idx]) -@test get_obs(sol) == sol[:, l1x_idx] + sol[:, l2x_idx] -@test get_obs_arr(sol) == vcat.(sol[:, l1x_idx] + sol[:, l2x_idx], sol[:, l1y_idx] + sol[:, l2y_idx]) +@test getx(sol) == sol[l1x_idx, :] +@test get_arr(sol) == vcat.(sol[l1x_idx, :], sol[l2x_idx, :]) +@test get_tuple(sol) == tuple.(sol[l1x_idx, :], sol[l2x_idx, :]) +@test get_obs(sol) == sol[l1x_idx, :] + sol[l2x_idx, :] +@test get_obs_arr(sol) == vcat.(sol[l1x_idx, :] + sol[l2x_idx, :], sol[l1y_idx, :] + sol[l2y_idx, :]) #= using Plots @@ -217,14 +217,107 @@ sol = solve(prob, Tsit5()) @test sol[@nonamespace sys.x] isa Vector{<:Vector} @test sol.ps[p] == [1, 2, 3] -getx = getu(sys, x) -get_mix_arr = getu(sys, [x, y]) -get_mix_tuple = getu(sys, (x, y)) x_idx = variable_index.((sys,), [x[1], x[2], x[3]]) y_idx = variable_index(sys, y) -@test getx(sol) == sol[:, x_idx] -@test get_mix_arr(sol) == vcat.(sol[:, x_idx], sol[:, y_idx]) -@test get_mix_tuple(sol) == tuple.(sol[:, x_idx], sol[:, y_idx]) +x_val = vcat.(getindex.((sol,), x_idx, :)...) +y_val = sol[y_idx, :] +obs_val = sol[x[1] + y] + +# checking inference for mixed-type arrays will always fail +for (sym, val, check_inference) in [ + (x, x_val, true), + (y, y_val, true), + (y_idx, y_val, true), + (x_idx, x_val, true), + (x[1] + y, obs_val, true), + ([x[1], x[2]], sol[[x[1], x[2]]], true), + ([x[1], x_idx[2]], sol[[x[1], x[2]]], false), + ([x, x[1] + y], [[i, j] for (i, j) in zip(x_val, obs_val)], false), + ([x, y], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([x, y_idx], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([x, x], [[i, i] for i in x_val], true), + ([x, x_idx], [[i, i] for i in x_val], false), + ((x, y), [(i, j) for (i, j) in zip(x_val, y_val)], true), + ((x, y_idx), [(i, j) for (i, j) in zip(x_val, y_val)], true), + ((x, x), [(i, i) for i in x_val], true), + ((x, x_idx), [(i, i) for i in x_val], true), + ((x, x[1]+y), [(i, j) for (i, j) in zip(x_val, obs_val)], true), + ((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], true), + ([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false), + ((x, [x[1] + y, y], (x[1] + y, y_idx)), [(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), + ([x, [x[1] + y, y], (x[1] + y, y_idx)], [[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false), +] + if check_inference + @inferred getu(prob, sym)(sol) + end + @test getu(prob, sym)(sol) == val +end + +x_newval = [3.0, 6.0, 9.0] +y_newval = 4.0 +x_probval = prob[x] +y_probval = prob[y] + +for (sym, oldval, newval, check_inference) in [ + (x, x_probval, x_newval, true), + (y, y_probval, y_newval, true), + (x_idx, x_probval, x_newval, true), + (y_idx, y_probval, y_newval, true), + ((x, y), (x_probval, y_probval), (x_newval, y_newval), true), + ([x, y], [x_probval, y_probval], [x_newval, y_newval], false), + ((x, y_idx), (x_probval, y_probval), (x_newval, y_newval), true), + ([x, y_idx], [x_probval, y_probval], [x_newval, y_newval], false), + ((x_idx, y), (x_probval, y_probval), (x_newval, y_newval), true), + ([x_idx, y], [x_probval, y_probval], [x_newval, y_newval], false), + ([x[1:2], [y_idx, x[3]]], [x_probval[1:2], [y_probval, x_probval[3]]], [x_newval[1:2], [y_newval, x_newval[3]]], true), + ([x[1:2], (y_idx, x[3])], [x_probval[1:2], (y_probval, x_probval[3])], [x_newval[1:2], (y_newval, x_newval[3])], false), + ((x[1:2], [y_idx, x[3]]), (x_probval[1:2], [y_probval, x_probval[3]]), (x_newval[1:2], [y_newval, x_newval[3]]), true), + ((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])), (x_newval[1:2], (y_newval, x_newval[3])), true), +] + getter = getu(prob, sym) + setter! = setu(prob, sym) + if check_inference + @inferred getter(prob) + end + @test getter(prob) == oldval + if check_inference + @inferred setter!(prob, newval) + else + setter!(prob, newval) + end + @test getter(prob) == newval + setter!(prob, oldval) + @test getter(prob) == oldval +end + +pval = [1.0, 2.0, 3.0] +pval_new = [4.0, 5.0, 6.0] + +for (sym, oldval, newval, check_inference) in [ + (p[1], pval[1], pval_new[1], true), + (p, pval, pval_new, true), + ((p[1], p[2]), Tuple(pval[1:2]), Tuple(pval_new[1:2]), true), + ([p[1], p[2]], pval[1:2], pval_new[1:2], true), + ((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true), + ([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false), + ((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]), (pval_new[1], (pval_new[2],), [pval_new[3]]), true), + ([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]], [pval_new[1], (pval_new[2],), [pval_new[3]]], false), +] + getter = getp(prob, sym) + setter! = setp(prob, sym) + if check_inference + @inferred getter(prob) + end + @test getter(prob) == oldval + if check_inference + @inferred setter!(prob, newval) + else + setter!(prob, newval) + end + @test getter(prob) == newval + setter!(prob, oldval) + @test getter(prob) == oldval +end # accessing parameters @variables t x(t) From e25f46dc01b4ec325835084dbdfbe6ce0e197ecb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 24 Jan 2024 12:59:37 +0530 Subject: [PATCH 6/6] fix: activate downstream env in SymbolicIndexingInterface test group --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index c9aafe6ae..c904a6bd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -104,6 +104,9 @@ end end if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface") + if GROUP != "Downstream" + activate_downstream_env() + end @time @safetestset "Symbol and integer based indexing of interpolated solutions" begin include("downstream/symbol_indexing.jl") end