Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use getu/setu for all indexing #721

Merged
merged 5 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
- name: Load this and run the downstream tests
shell: julia --color=yes --project=downstream {0}
shell: julia --color=yes --project=downstream --depwarn=yes {0}
run: |
using Pkg
try
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
- "Python"
uses: "SciML/.github/.github/workflows/tests.yml@v1"
with:
julia-runtest-depwarn: "yes"
group: "${{ matrix.group }}"
julia-version: "${{ matrix.version }}"
secrets: "inherit"
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
SciMLStructures = "1.1"
StableRNGs = "1.0"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
Expand All @@ -106,11 +107,12 @@ PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
75 changes: 21 additions & 54 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,46 +465,20 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
end
end

Base.@propagate_inbounds function _getindex(A::DEIntegrator,
::NotSymbolic,
I::Union{Int, AbstractArray{Int},
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
A.u[I...]
end

Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing")
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.p, A.t)
else
error("Tried to index integrator with a Symbol that was not found in the system.")
Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym)
return A[collect(sym)]
end

Base.@propagate_inbounds function _getindex(
A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray})
return getindex.((A,), sym)
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

if symtype != NotSymbolic()
return _getindex(A, symtype, sym)
else
return _getindex(A, elsymtype, sym)
Base.@propagate_inbounds function Base.getindex(
A::DEIntegrator, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand All @@ -522,25 +496,18 @@ function observed(A::DEIntegrator, sym)
end

function Base.setindex!(A::DEIntegrator, val, sym)
has_sys(A.f) ||
error("Invalid indexing of integrator: Integrator does not support indexing without a system")
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
set_state!(A, val, variable_index(A, sym))
elseif is_parameter(A, sym)
error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.")
else
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
return A
elseif symbolic_type(sym) == ArraySymbolic()
setindex!.((A,), val, collect(sym))
return A
else
sym isa AbstractArray || error("Invalid indexing of integrator")
setindex!.((A,), val, sym)
return A
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
end
setu(A, sym)(A, val)
end

function Base.setindex!(A::DEIntegrator, val, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
end
setu(A, sym)(A, val)
end

### Integrator traits
Expand Down
67 changes: 27 additions & 40 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,51 +38,38 @@ Base.@propagate_inbounds function Base.getindex(
return getindex(prob, all_variable_symbols(prob))
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob, sym)
return state_values(prob, variable_index(prob, sym))
elseif is_parameter(prob, sym)
error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.")
elseif is_independent_variable(prob, sym)
return current_time(prob)
elseif is_observed(prob, sym)
obs = SymbolicIndexingInterface.observed(prob, sym)
if is_time_dependent(prob)
return obs(state_values(prob), parameter_values(prob), current_time(prob))
else
return obs(state_values(prob), parameter_values(prob))
end
else
error("Invalid indexing of problem: $sym is not a state, parameter, or independent variable")
end
elseif symbolic_type(sym) == ArraySymbolic()
return map(s -> prob[s], collect(sym))
else
sym isa AbstractArray || error("Invalid indexing of problem")
return map(s -> prob[s], sym)
Base.@propagate_inbounds function Base.getindex(A::AbstractSciMLProblem, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
A::AbstractSciMLProblem, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

function Base.setindex!(prob::AbstractSciMLProblem, args...; kwargs...)
___internal_setindex!(prob::AbstractSciMLProblem, args...; kwargs...)
end
function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob, sym)
set_state!(prob, val, variable_index(prob, sym))
elseif is_parameter(prob, sym)
error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.")
else
error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.")
end
return prob
elseif symbolic_type(sym) == ArraySymbolic()
setindex!.((prob,), val, collect(sym))
return prob
else
sym isa AbstractArray || error("Invalid indexing of problem")
setindex!.((prob,), val, sym)
return prob

function ___internal_setindex!(A::AbstractSciMLProblem, val, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
end
return setu(A, sym)(A, val)
end

function ___internal_setindex!(
A::AbstractSciMLProblem, val, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
end
return setu(A, sym)(A, val)
end
8 changes: 7 additions & 1 deletion src/solutions/optimization_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ function reinit!(cache::SciMLBase.AbstractOptimizationCache; p = missing,
return cache
end

SymbolicIndexingInterface.parameter_values(x::AbstractOptimizationCache) = x.p
function SymbolicIndexingInterface.parameter_values(x::AbstractOptimizationCache)
if has_reinit(x)
x.reinit_cache.p
else
x.p
end
end
SymbolicIndexingInterface.symbolic_container(x::AbstractOptimizationCache) = x.f

get_p(sol::OptimizationSolution) = sol.cache.p
Expand Down
27 changes: 12 additions & 15 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,19 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, :
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.")
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, parameter_values(A))
else
error("Tried to index solution with a Symbol that was not found in the system.")
end
elseif symbolic_type(sym) == ArraySymbolic()
return A[collect(sym)]
else
sym isa AbstractArray || error("Invalid indexing of solution")
return getindex.((A,), sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
A::AbstractNoTimeSolution, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand Down
4 changes: 4 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Expand All @@ -34,6 +37,7 @@ RecursiveArrayTools = "3"
SciMLBase = "2"
SciMLSensitivity = "7.11"
SciMLStructures = "1.1"
SteadyStateDiffEq = "2.2"
Sundials = "4.11"
SymbolicIndexingInterface = "0.3"
SymbolicUtils = "<1.6, 2"
Expand Down
Loading
Loading