Skip to content

Commit

Permalink
feat: add support for new symbol indexing methods in SII
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 27, 2023
1 parent 3bfb948 commit 42339fb
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 18 deletions.
26 changes: 16 additions & 10 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,18 +450,17 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},

Check warning on line 453 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L453

Added line #L453 was not covered by tests
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
A.u[I...]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym)

Check warning on line 459 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L459

Added line #L459 was not covered by tests
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing")

Check warning on line 463 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L463

Added line #L463 was not covered by tests
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
Expand All @@ -471,11 +470,11 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymboli
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym)
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym)

Check warning on line 473 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L473

Added line #L473 was not covered by tests
return A[collect(sym)]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})

Check warning on line 477 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L477

Added line #L477 was not covered by tests
return getindex.((A,), sym)
end

Expand All @@ -484,12 +483,20 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
elsymtype = symbolic_type(eltype(sym))

if symtype != NotSymbolic()
return getindex(A, symtype, sym)
return _getindex(A, symtype, sym)

Check warning on line 486 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L486

Added line #L486 was not covered by tests
else
return getindex(A, elsymtype, sym)
return _getindex(A, elsymtype, sym)

Check warning on line 488 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L488

Added line #L488 was not covered by tests
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(A, variable_symbols(A))

Check warning on line 493 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L492-L493

Added lines #L492 - L493 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.AllVariables)
return getindex(A, all_variable_symbols(A))

Check warning on line 497 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L496-L497

Added lines #L496 - L497 were not covered by tests
end

function observed(A::DEIntegrator, sym)
getobserved(A)(sym, A.u, A.p, A.t)
end
Expand All @@ -500,8 +507,7 @@ function Base.setindex!(A::DEIntegrator, val, sym)
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)
error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.")

Check warning on line 510 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L510

Added line #L510 was not covered by tests
else
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
14 changes: 10 additions & 4 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(prob, variable_symbols(prob))

Check warning on line 5 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L4-L5

Added lines #L4 - L5 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.AllVariables)
return getindex(prob, all_variable_symbols(prob))

Check warning on line 9 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob.f, sym)
return prob.u0[variable_index(prob.f, sym)]
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.", :parameter_getindex)
return getp(prob, sym)(prob)
error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.")

Check warning on line 17 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L17

Added line #L17 was not covered by tests
elseif is_independent_variable(prob.f, sym)
return getindepsym(prob)
elseif is_observed(prob.f, sym)
Expand Down Expand Up @@ -37,8 +44,7 @@ function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym)
if is_variable(prob.f, sym)
prob.u0[variable_index(prob.f, sym)] = val
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.", :parameter_setindex)
setp(prob, sym)(prob, val)
error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.", :parameter_setindex)

Check warning on line 47 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L47

Added line #L47 was not covered by tests
else
error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
11 changes: 9 additions & 2 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.")

Check warning on line 76 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L76

Added line #L76 was not covered by tests
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.prob.p)
else
Expand All @@ -88,6 +87,14 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(A, variable_symbols(A))

Check warning on line 91 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L90-L91

Added lines #L90 - L91 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.AllVariables)
return getindex(A, all_variable_symbols(A))

Check warning on line 95 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L94-L95

Added lines #L94 - L95 were not covered by tests
end

function observed(A::AbstractTimeseriesSolution, sym, i::Int)
getobserved(A)(sym, A[i], A.prob.p, A.t[i])
end
Expand Down
3 changes: 2 additions & 1 deletion test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ integrator = init(oprob, Rodas4())

@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0
@test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0

@test integrator[solvedvariables] == integrator.u
@test integrator[allvariables] == integrator.u
step!(integrator, 100.0, true)

@test getp(population_model, a)(integrator) == getp(population_model, population_model.a)(integrator) == getp(population_model, :a)(integrator) == 2.0
Expand Down
3 changes: 3 additions & 0 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, OrdinaryDiffEq, Test
using SymbolicIndexingInterface

@parameters σ ρ β
@variables t x(t) y(t) z(t)
Expand Down Expand Up @@ -33,6 +34,8 @@ oprob = ODEProblem(sys, u0, tspan, p, jac = true)
@test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0
@test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0
@test oprob[z] == oprob[sys.z] == oprob[:z] == 0.0
@test oprob[solvedvariables] == [2.0, 1.0, 0.0, 0.0]
@test oprob[allvariables] == [2.0, 1.0, 0.0, 0.0]

oprob[σ] = 10.0
@test oprob[σ] == oprob[sys.σ] == oprob[] == 10.0
Expand Down
12 changes: 11 additions & 1 deletion test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ sol = solve(prob, Rodas4())
@test_throws Any sol['a', [1, 2, 3]]

@test sol[a] isa AbstractVector
@test sol[:a] == sol[a]
@test sol[a, 1] isa Real
@test sol[:a, 1] == sol[a, 1]
@test sol[a, 1:5] isa AbstractVector
@test sol[:a, 1:5] == sol[a, 1:5]
@test sol[a, [1, 2, 3]] isa AbstractVector
@test sol[:a, [1, 2, 3]] == sol[a, [1, 2, 3]]

@test sol[:, 1] isa AbstractVector
@test sol[:, 1:2] isa AbstractDiffEqArray
Expand All @@ -68,7 +72,7 @@ sol = solve(prob, Rodas4())
@test sol[α, 3] isa Float64
@test length(sol[α, 5:10]) == 6
@test getp(prob, γ)(sol) isa Real
@test getp(prob, γ)(sol) == 2.0
@test getp(prob, γ)(sol) == getp(prob, )(sol) == 2.0
@test getp(prob, (lorenz1.σ, lorenz1.ρ))(sol) isa Tuple

@test sol[[lorenz1.x, lorenz2.x]] isa Vector{Vector{Float64}}
Expand Down Expand Up @@ -171,6 +175,12 @@ sol9 = sol(0.0:1.0:10.0, idxs = 2)
sol10 = sol(0.1, idxs = 2)
@test sol10 isa Real

sol11 = solve(prob, Rodas4(), save_idxs = [lorenz1.x, lorenz2.z, a])
@test length(sol11[:, 1]) == 3
@test sol11[lorenz1.x] == sol11[:, 1] == sol[lorenz1.x]
@test sol11[lorenz2.z] == sol11[:, 2] == sol[lorenz2.z]
@test sol11[a] == sol11[:, 3] == sol[a]

#=
using Plots
plot(sol,idxs=(lorenz2.x,lorenz2.z))
Expand Down

0 comments on commit 42339fb

Please sign in to comment.