Skip to content

Commit

Permalink
Merge pull request #571 from AayushSabharwal/as/tests
Browse files Browse the repository at this point in the history
feat: add support for new symbol indexing methods in SII
  • Loading branch information
ChrisRackauckas authored Dec 28, 2023
2 parents f17b506 + 1bdb65c commit 1b995b6
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 65 deletions.
8 changes: 4 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ using Zygote
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, sym_to_index, remake,
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index
using RecursiveArrayTools

# This method resolves the ambiguity with the pullback defined in
Expand All @@ -34,7 +34,7 @@ end

@adjoint function getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -96,7 +96,7 @@ end

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
2 changes: 1 addition & 1 deletion src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
Base.@propagate_inbounds function RecursiveArrayTools._getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
return [xi[s] for xi in x.u]
end

Expand Down
26 changes: 16 additions & 10 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,18 +450,17 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
A.u[I...]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing")
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
Expand All @@ -471,11 +470,11 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymboli
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym)
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym)
return A[collect(sym)]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
return getindex.((A,), sym)
end

Expand All @@ -484,12 +483,20 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
elsymtype = symbolic_type(eltype(sym))

if symtype != NotSymbolic()
return getindex(A, symtype, sym)
return _getindex(A, symtype, sym)
else
return getindex(A, elsymtype, sym)
return _getindex(A, elsymtype, sym)
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(A, variable_symbols(A))
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.AllVariables)
return getindex(A, all_variable_symbols(A))
end

function observed(A::DEIntegrator, sym)
getobserved(A)(sym, A.u, A.p, A.t)
end
Expand All @@ -500,8 +507,7 @@ function Base.setindex!(A::DEIntegrator, val, sym)
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)
error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.")
else
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
14 changes: 10 additions & 4 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(prob, variable_symbols(prob))
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.AllVariables)
return getindex(prob, all_variable_symbols(prob))
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob.f, sym)
return prob.u0[variable_index(prob.f, sym)]
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.", :parameter_getindex)
return getp(prob, sym)(prob)
error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.")
elseif is_independent_variable(prob.f, sym)
return getindepsym(prob)
elseif is_observed(prob.f, sym)
Expand Down Expand Up @@ -37,8 +44,7 @@ function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym)
if is_variable(prob.f, sym)
prob.u0[variable_index(prob.f, sym)] = val
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.", :parameter_setindex)
setp(prob, sym)(prob, val)
error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.")
else
error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
18 changes: 13 additions & 5 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,22 +179,30 @@ end
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
else
return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]
end
end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
[first(interp_sol[idx]) for idx in idxs]
[is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs]
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(interp_sol[idxs], t, p, sol)
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
else
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(interp_sol[idxs], t, p, sol)
end
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
Expand Down
11 changes: 9 additions & 2 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.")
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.prob.p)
else
Expand All @@ -88,6 +87,14 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(A, variable_symbols(A))
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.AllVariables)
return getindex(A, all_variable_symbols(A))
end

function observed(A::AbstractTimeseriesSolution, sym, i::Int)
getobserved(A)(sym, A[i], A.prob.p, A.t[i])
end
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/ensemble_nondes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistr
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5])
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))
ensembleprob = Optimization.EnsembleProblem(prob, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
Expand All @@ -35,4 +35,4 @@ ensembleprob = EnsembleProblem(prob, [u0, u0 .+ rand(2), u0 .+ rand(2), u0 .+ ra

sol = solve(ensembleprob, EnsembleThreads(), trajectories = 4, maxiters = 100)

sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)
sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)
15 changes: 8 additions & 7 deletions test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@ tspan = (0.0, 1000000.0)
oprob = ODEProblem(population_model, u0, tspan, p)
integrator = init(oprob, Rodas4())

@test_deprecated integrator[a]
@test_deprecated integrator[population_model.a]
@test_deprecated integrator[:a]
@test_throws Exception integrator[a]
@test_throws Exception integrator[population_model.a]
@test_throws Exception integrator[:a]
@test getp(oprob, a)(integrator) == getp(oprob, population_model.a)(integrator) == getp(oprob, :a)(integrator) == 2.0
@test getp(oprob, b)(integrator) == getp(oprob, population_model.b)(integrator) == getp(oprob, :b)(integrator) == 1.0
@test getp(oprob, c)(integrator) == getp(oprob, population_model.c)(integrator) == getp(oprob, :c)(integrator) == 1.0
@test getp(oprob, d)(integrator) == getp(oprob, population_model.d)(integrator) == getp(oprob, :d)(integrator) == 1.0

@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0
@test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0

@test integrator[solvedvariables] == integrator.u
@test integrator[allvariables] == integrator.u
step!(integrator, 100.0, true)

@test getp(population_model, a)(integrator) == getp(population_model, population_model.a)(integrator) == getp(population_model, :a)(integrator) == 2.0
Expand Down Expand Up @@ -299,6 +300,6 @@ eqs = [collect(D.(x) .~ x)
D(y) ~ norm(x) * y - x[1]]
@named sys = ODESystem(eqs, t, [sts...;], [ps...;])
prob = ODEProblem(sys, [], (0, 1.0))
@test_broken local integrator = init(prob, Tsit5())
@test_broken integrator[x] isa Vector{<:Vector}
@test_broken integrator[@nonamespace sys.x] isa Vector{<:Vector}
integrator = init(prob, Tsit5())
@test integrator[x] isa Vector{Float64}
@test integrator[@nonamespace sys.x] isa Vector{Float64}
58 changes: 39 additions & 19 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, OrdinaryDiffEq, Test
using SymbolicIndexingInterface

@parameters σ ρ β
@variables t x(t) y(t) z(t)
Expand Down Expand Up @@ -26,20 +27,37 @@ tspan = (0.0, 100.0)
# ODEProblem.
oprob = ODEProblem(sys, u0, tspan, p, jac = true)

@test oprob[σ] == oprob[sys.σ] == oprob[] == 28.0
@test oprob[ρ] == oprob[sys.ρ] == oprob[] == 10.0
@test oprob[β] == oprob[sys.β] == oprob[] == 8 / 3
@test_throws Exception oprob[σ]
@test_throws Exception oprob[sys.σ]
@test_throws Exception oprob[]
getσ1 = getp(sys, σ)
getσ2 = getp(sys, sys.σ)
getσ3 = getp(sys, )
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 28.0
getρ1 = getp(sys, ρ)
getρ2 = getp(sys, sys.ρ)
getρ3 = getp(sys, )
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 10.0
getβ1 = getp(sys, β)
getβ2 = getp(sys, sys.β)
getβ3 = getp(sys, )
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 8 / 3

@test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0
@test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0
@test oprob[z] == oprob[sys.z] == oprob[:z] == 0.0

oprob[σ] = 10.0
@test oprob[σ] == oprob[sys.σ] == oprob[] == 10.0
oprob[sys.ρ] = 20.0
@test oprob[ρ] == oprob[sys.ρ] == oprob[] == 20.0
oprob[σ] = 30.0
@test oprob[σ] == oprob[sys.σ] == oprob[] == 30.0
@test oprob[solvedvariables] == oprob[variable_symbols(sys)]
@test oprob[allvariables] == oprob[all_variable_symbols(sys)]

setσ = setp(sys, σ)
setσ(oprob, 10.0)
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 10.0
setρ = setp(sys, sys.ρ)
setρ(oprob, 20.0)
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 20.0
setβ = setp(sys, )
setβ(oprob, 30.0)
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 30.0

oprob[x] = 10.0
@test oprob[x] == oprob[sys.x] == oprob[:x] == 10.0
Expand All @@ -56,20 +74,22 @@ noiseeqs = [0.1 * x,
sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p)
u0

@test sprob[σ] == sprob[noise_sys.σ] == sprob[] == 28.0
@test sprob[ρ] == sprob[noise_sys.ρ] == sprob[] == 10.0
@test sprob[β] == sprob[noise_sys.β] == sprob[] == 8 / 3
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 28.0
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 10.0
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 8 / 3

@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 1.0
@test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 0.0
@test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 0.0

sprob[σ] = 10.0
@test sprob[σ] == sprob[noise_sys.σ] == sprob[] == 10.0
sprob[noise_sys.ρ] = 20.0
@test sprob[ρ] == sprob[noise_sys.ρ] == sprob[] == 20.0
sprob[σ] = 30.0
@test sprob[σ] == sprob[noise_sys.σ] == sprob[] == 30.0
setσ(sprob, 10.0)
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 10.0
setρ(sprob, 20.0)
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 20.0
setp(noise_sys, noise_sys.ρ)(sprob, 25.0)
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 25.0
setβ(sprob, 30.0)
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 30.0

sprob[x] = 10.0
@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 10.0
Expand Down
12 changes: 6 additions & 6 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ sol = solve(oprob, Rodas4())
@test sol[s1] == sol[population_model.s1] == sol[:s1]
@test sol[s2] == sol[population_model.s2] == sol[:s2]
@test sol[s1][end] 1.0
@test_deprecated sol[a]
@test_deprecated sol[population_model.a]
@test_deprecated sol[:a]
@test_throws Exception sol[a]
@test_throws Exception sol[population_model.a]
@test_throws Exception sol[:a]

# Tests on SDEProblem
noiseeqs = [0.1 * s1,
Expand All @@ -34,9 +34,9 @@ sol = solve(sprob, ImplicitEM())

@test sol[s1] == sol[noisy_population_model.s1] == sol[:s1]
@test sol[s2] == sol[noisy_population_model.s2] == sol[:s2]
@test_deprecated sol[a]
@test_deprecated sol[noisy_population_model.a]
@test_deprecated sol[:a]
@test_throws Exception sol[a]
@test_throws Exception sol[noisy_population_model.a]
@test_throws Exception sol[:a]
### Tests on layered model (some things should not work). ###

@parameters t σ ρ β
Expand Down
Loading

0 comments on commit 1b995b6

Please sign in to comment.