Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for new symbol indexing methods in SII #571

Merged
merged 6 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ using Zygote
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, sym_to_index, remake,
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index
using RecursiveArrayTools

# This method resolves the ambiguity with the pullback defined in
Expand All @@ -34,7 +34,7 @@ end

@adjoint function getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -96,7 +96,7 @@ end

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
2 changes: 1 addition & 1 deletion src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
Base.@propagate_inbounds function RecursiveArrayTools._getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
return [xi[s] for xi in x.u]
end

Expand Down
26 changes: 16 additions & 10 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,18 +450,17 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
A.u[I...]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing")
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
Expand All @@ -471,11 +470,11 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymboli
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym)
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym)
return A[collect(sym)]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
return getindex.((A,), sym)
end

Expand All @@ -484,12 +483,20 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
elsymtype = symbolic_type(eltype(sym))

if symtype != NotSymbolic()
return getindex(A, symtype, sym)
return _getindex(A, symtype, sym)
else
return getindex(A, elsymtype, sym)
return _getindex(A, elsymtype, sym)
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(A, variable_symbols(A))
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.AllVariables)
return getindex(A, all_variable_symbols(A))
end

function observed(A::DEIntegrator, sym)
getobserved(A)(sym, A.u, A.p, A.t)
end
Expand All @@ -500,8 +507,7 @@ function Base.setindex!(A::DEIntegrator, val, sym)
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)
error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.")
else
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
14 changes: 10 additions & 4 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(prob, variable_symbols(prob))
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.AllVariables)
return getindex(prob, all_variable_symbols(prob))
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob.f, sym)
return prob.u0[variable_index(prob.f, sym)]
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.", :parameter_getindex)
return getp(prob, sym)(prob)
error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.")
elseif is_independent_variable(prob.f, sym)
return getindepsym(prob)
elseif is_observed(prob.f, sym)
Expand Down Expand Up @@ -37,8 +44,7 @@ function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym)
if is_variable(prob.f, sym)
prob.u0[variable_index(prob.f, sym)] = val
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.", :parameter_setindex)
setp(prob, sym)(prob, val)
error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.")
else
error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
18 changes: 13 additions & 5 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,22 +179,30 @@ end
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
else
return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]
end
end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
[first(interp_sol[idx]) for idx in idxs]
[is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs]
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(interp_sol[idxs], t, p, sol)
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
else
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(interp_sol[idxs], t, p, sol)
end
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
Expand Down
11 changes: 9 additions & 2 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.")
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.prob.p)
else
Expand All @@ -88,6 +87,14 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(A, variable_symbols(A))
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.AllVariables)
return getindex(A, all_variable_symbols(A))
end

function observed(A::AbstractTimeseriesSolution, sym, i::Int)
getobserved(A)(sym, A[i], A.prob.p, A.t[i])
end
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/ensemble_nondes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistr
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5])
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))
ensembleprob = Optimization.EnsembleProblem(prob, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
Expand All @@ -35,4 +35,4 @@ ensembleprob = EnsembleProblem(prob, [u0, u0 .+ rand(2), u0 .+ rand(2), u0 .+ ra

sol = solve(ensembleprob, EnsembleThreads(), trajectories = 4, maxiters = 100)

sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)
sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)
15 changes: 8 additions & 7 deletions test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@ tspan = (0.0, 1000000.0)
oprob = ODEProblem(population_model, u0, tspan, p)
integrator = init(oprob, Rodas4())

@test_deprecated integrator[a]
@test_deprecated integrator[population_model.a]
@test_deprecated integrator[:a]
@test_throws Exception integrator[a]
@test_throws Exception integrator[population_model.a]
@test_throws Exception integrator[:a]
@test getp(oprob, a)(integrator) == getp(oprob, population_model.a)(integrator) == getp(oprob, :a)(integrator) == 2.0
@test getp(oprob, b)(integrator) == getp(oprob, population_model.b)(integrator) == getp(oprob, :b)(integrator) == 1.0
@test getp(oprob, c)(integrator) == getp(oprob, population_model.c)(integrator) == getp(oprob, :c)(integrator) == 1.0
@test getp(oprob, d)(integrator) == getp(oprob, population_model.d)(integrator) == getp(oprob, :d)(integrator) == 1.0

@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0
@test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0

@test integrator[solvedvariables] == integrator.u
@test integrator[allvariables] == integrator.u
step!(integrator, 100.0, true)

@test getp(population_model, a)(integrator) == getp(population_model, population_model.a)(integrator) == getp(population_model, :a)(integrator) == 2.0
Expand Down Expand Up @@ -299,6 +300,6 @@ eqs = [collect(D.(x) .~ x)
D(y) ~ norm(x) * y - x[1]]
@named sys = ODESystem(eqs, t, [sts...;], [ps...;])
prob = ODEProblem(sys, [], (0, 1.0))
@test_broken local integrator = init(prob, Tsit5())
@test_broken integrator[x] isa Vector{<:Vector}
@test_broken integrator[@nonamespace sys.x] isa Vector{<:Vector}
integrator = init(prob, Tsit5())
@test integrator[x] isa Vector{Float64}
@test integrator[@nonamespace sys.x] isa Vector{Float64}
58 changes: 39 additions & 19 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, OrdinaryDiffEq, Test
using SymbolicIndexingInterface

@parameters σ ρ β
@variables t x(t) y(t) z(t)
Expand Down Expand Up @@ -26,20 +27,37 @@ tspan = (0.0, 100.0)
# ODEProblem.
oprob = ODEProblem(sys, u0, tspan, p, jac = true)

@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 28.0
@test oprob[ρ] == oprob[sys.ρ] == oprob[:ρ] == 10.0
@test oprob[β] == oprob[sys.β] == oprob[:β] == 8 / 3
@test_throws Exception oprob[σ]
@test_throws Exception oprob[sys.σ]
@test_throws Exception oprob[:σ]
getσ1 = getp(sys, σ)
getσ2 = getp(sys, sys.σ)
getσ3 = getp(sys, :σ)
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 28.0
getρ1 = getp(sys, ρ)
getρ2 = getp(sys, sys.ρ)
getρ3 = getp(sys, :ρ)
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 10.0
getβ1 = getp(sys, β)
getβ2 = getp(sys, sys.β)
getβ3 = getp(sys, :β)
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 8 / 3

@test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0
@test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0
@test oprob[z] == oprob[sys.z] == oprob[:z] == 0.0

oprob[σ] = 10.0
@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 10.0
oprob[sys.ρ] = 20.0
@test oprob[ρ] == oprob[sys.ρ] == oprob[:ρ] == 20.0
oprob[σ] = 30.0
@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 30.0
@test oprob[solvedvariables] == oprob[variable_symbols(sys)]
@test oprob[allvariables] == oprob[all_variable_symbols(sys)]

setσ = setp(sys, σ)
setσ(oprob, 10.0)
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 10.0
setρ = setp(sys, sys.ρ)
setρ(oprob, 20.0)
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 20.0
setβ = setp(sys, :β)
setβ(oprob, 30.0)
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 30.0

oprob[x] = 10.0
@test oprob[x] == oprob[sys.x] == oprob[:x] == 10.0
Expand All @@ -56,20 +74,22 @@ noiseeqs = [0.1 * x,
sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p)
u0

@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 28.0
@test sprob[ρ] == sprob[noise_sys.ρ] == sprob[:ρ] == 10.0
@test sprob[β] == sprob[noise_sys.β] == sprob[:β] == 8 / 3
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 28.0
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 10.0
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 8 / 3

@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 1.0
@test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 0.0
@test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 0.0

sprob[σ] = 10.0
@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 10.0
sprob[noise_sys.ρ] = 20.0
@test sprob[ρ] == sprob[noise_sys.ρ] == sprob[:ρ] == 20.0
sprob[σ] = 30.0
@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 30.0
setσ(sprob, 10.0)
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 10.0
setρ(sprob, 20.0)
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 20.0
setp(noise_sys, noise_sys.ρ)(sprob, 25.0)
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 25.0
setβ(sprob, 30.0)
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 30.0

sprob[x] = 10.0
@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 10.0
Expand Down
12 changes: 6 additions & 6 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ sol = solve(oprob, Rodas4())
@test sol[s1] == sol[population_model.s1] == sol[:s1]
@test sol[s2] == sol[population_model.s2] == sol[:s2]
@test sol[s1][end] ≈ 1.0
@test_deprecated sol[a]
@test_deprecated sol[population_model.a]
@test_deprecated sol[:a]
@test_throws Exception sol[a]
@test_throws Exception sol[population_model.a]
@test_throws Exception sol[:a]

# Tests on SDEProblem
noiseeqs = [0.1 * s1,
Expand All @@ -34,9 +34,9 @@ sol = solve(sprob, ImplicitEM())

@test sol[s1] == sol[noisy_population_model.s1] == sol[:s1]
@test sol[s2] == sol[noisy_population_model.s2] == sol[:s2]
@test_deprecated sol[a]
@test_deprecated sol[noisy_population_model.a]
@test_deprecated sol[:a]
@test_throws Exception sol[a]
@test_throws Exception sol[noisy_population_model.a]
@test_throws Exception sol[:a]
### Tests on layered model (some things should not work). ###

@parameters t σ ρ β
Expand Down
Loading