Skip to content

Commit

Permalink
Merge pull request #532 from AayushSabharwal/as/indexing-rework
Browse files Browse the repository at this point in the history
feat: indexing rework with new SymbolicIndexingInterface
  • Loading branch information
ChrisRackauckas authored Dec 12, 2023
2 parents 9fdb9c8 + bbba9c6 commit dc1c83d
Show file tree
Hide file tree
Showing 18 changed files with 625 additions and 891 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.13.18"
RecipesBase = "1.0"
RecursiveArrayTools = "2.38"
RecursiveArrayTools = "3.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5"
SciMLOperators = "0.3.7"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.9"
SymbolicIndexingInterface = "0.2"
SymbolicIndexingInterface = "0.3"
Tables = "1.11"
TruncatedStacktraces = "1.4"
QuasiMonteCarlo = "0.2.19, 0.3"
Expand Down
4 changes: 2 additions & 2 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{
sym,
j::Integer)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
if i === nothing
getter = getobserved(VA)
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -65,7 +65,7 @@ end

function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
if i === nothing
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
9 changes: 5 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ using Zygote
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
using SciMLBase: ODESolution, sym_to_index, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand All @@ -32,7 +33,7 @@ end

@adjoint function getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -81,7 +82,7 @@ end
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
VA[i], ODESolution_getindex_pullback
VA[:, i], ODESolution_getindex_pullback
end

@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
Expand All @@ -91,7 +92,7 @@ end

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
20 changes: 5 additions & 15 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ end # Vector of some type which is not an array
function EnsembleSolution(sim::T, elapsedTime,
converged, stats=nothing) where {T <: AbstractVector{T2}
} where {T2 <:
AbstractArray}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1, typeof(sim)}(
sim,
Union{AbstractArray,RecursiveArrayTools.AbstractVectorOfArray}}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1,
typeof(sim)}(sim,
elapsedTime,
converged,
stats)
Expand Down Expand Up @@ -209,18 +209,8 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,
::Colon,
args::Colon...)
return invoke(getindex,
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
return [xi[s] for xi in x.u]
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
184 changes: 87 additions & 97 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,13 @@ function set_u!(integrator::DEIntegrator, sym, val)
# So any error checking happens to ensure we actually _can_ set state
set_u!(integrator, integrator.u)

if !issymbollike(sym)
if symbolic_type(sym) == NotSymbolic()
error("sym must be a symbol")
end
i = sym_to_index(sym, integrator)
i = variable_index(integrator, sym)

if isnothing(i)
error("sym is not a state variable")
error("$sym is not a state variable")
end

integrator.u[i] = val
Expand Down Expand Up @@ -385,27 +385,27 @@ end

### Indexing
function getsyms(integrator::DEIntegrator)
if has_syms(integrator.f)
return integrator.f.syms
else
return keys(integrator.u[1])
syms = variable_symbols(integrator)
if isempty(syms)
syms = keys(integrator.u)
end
return syms
end

function getindepsym(integrator::DEIntegrator)
if has_indepsym(integrator.f)
return integrator.f.indepsym
else
syms = independent_variable_symbols(integrator)
if isempty(syms)
return nothing
end
return syms
end

function getparamsyms(integrator::DEIntegrator)
if has_paramsyms(integrator.f)
return integrator.f.paramsyms
else
psyms = parameter_symbols(integrator)
if isempty(psyms)
return nothing
end
return psyms
end

function getobserved(integrator::DEIntegrator)
Expand All @@ -417,58 +417,76 @@ function getobserved(integrator::DEIntegrator)
end

function sym_to_index(sym, integrator::DEIntegrator)
if has_sys(integrator.f) && is_state_sym(integrator.f.sys, sym)
return state_sym_to_index(integrator.f.sys, sym)
idx = variable_index(integrator, sym)
if idx === nothing
idx = findfirst(isequal(sym), keys(integrator.u))
end
return idx
end

# SymbolicIndexingInterface
SymbolicIndexingInterface.symbolic_container(A::DEIntegrator) = A.f
SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p

function SymbolicIndexingInterface.is_observed(A::DEIntegrator, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()
end

function SymbolicIndexingInterface.observed(A::DEIntegrator, sym)
(u, p, t) -> getobserved(A)(sym, u, p, t)
end

SymbolicIndexingInterface.is_time_dependent(::DEIntegrator) = true

# TODO make this nontrivial once dynamic state selection works
SymbolicIndexingInterface.constant_structure(::DEIntegrator) = true

function Base.getproperty(A::DEIntegrator, sym::Symbol)
if sym === :destats && hasfield(typeof(A), :stats)
@warn "destats has been deprecated for stats"
getfield(A, :stats)
else
return sym_to_index(sym, getsyms(integrator))
return getfield(A, sym)
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator,
I::Union{Int, AbstractArray{Int},
Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
RecursiveArrayTools.VectorOfArray(A.u)[I...]
A.u[I...]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.p, A.t)
else
error("Tried to index integrator with a Symbol that was not found in the system.")
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym)
return A[collect(sym)]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
return getindex.((A,), sym)
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
if issymbollike(sym)
if sym isa AbstractArray
return A[collect(sym)]
end
i = sym_to_index(sym, A)
elseif all(issymbollike, sym)
return getindex.((A,), sym)
else
i = sym
end
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

if i === nothing
if issymbollike(sym)
if has_sys(A.f) && is_indep_sym(A.f.sys, sym) ||
Symbol(sym) == getindepsym(A)
return A.t
elseif has_sys(A.f) && is_param_sym(A.f.sys, sym)
return A.p[param_sym_to_index(A.f.sys, sym)]
elseif has_paramsyms(A.f) && Symbol(sym) in getparamsyms(A)
return A.p[findfirst(x -> isequal(x, Symbol(sym)), getparamsyms(A))]
elseif (sym isa Symbol) && has_sys(A.f) && hasproperty(A.f.sys, sym) # Handles input like :X (where X is a state).
return observed(A, getproperty(A.f.sys, sym))
elseif has_sys(A.f) && (count('', String(Symbol(sym))) == 1) &&
(count(isequal(Symbol(sym)),
Symbol.(A.f.sys.name, :₊, getparamsyms(A))) == 1) # Handles input like sys.X (where X is a parameter).
return A.p[findfirst(isequal(Symbol(sym)),
Symbol.(A.f.sys.name, :₊, getparamsyms(A)))]
else
return observed(A, sym)
end
else
observed(A, sym)
end
elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer}
A[i]
if symtype != NotSymbolic()
return getindex(A, symtype, sym)
else
error("Invalid indexing of integrator")
return getindex(A, elsymtype, sym)
end
end

Expand All @@ -477,52 +495,24 @@ function observed(A::DEIntegrator, sym)
end

function Base.setindex!(A::DEIntegrator, val, sym)
if has_sys(A.f)
if issymbollike(sym)
params = getparamsyms(A)
s = Symbol.(states(A.f.sys))
params = Symbol.(params)

i = findfirst(isequal(Symbol(sym)), s)
if !isnothing(i)
A.u[i] = val
return A
elseif sym isa Symbol # Handles input like :X.
s_f = Symbol.(getproperty.(states(A.f.sys), :f))
if count(isequal(Symbol(sym)), s_f) == 1
i = findfirst(isequal(sym), s_f)
A.u[i] = val
return A
elseif count(isequal(Symbol(sym)), s_f) > 1
error("The input symbol $(sym) occurs several times among integrator states. Please avoid use Symbol form (:$(sym)).")
end
elseif count('', String(Symbol(sym))) == 1 # Handles input like sys.X.
s_names = Symbol.(A.f.sys.name, :₊, s)
if count(isequal(Symbol(sym)), s_names) == 1
i = findfirst(isequal(Symbol(sym)), s_names)
A.u[i] = val
return A
end
end

i = findfirst(isequal(Symbol(sym)), params)
if !isnothing(i)
A.p[i] = val
return A
elseif count('', String(Symbol(sym))) == 1 # Handles input like sys.X.
p_names = Symbol.(A.f.sys.name, :₊, params)
if count(isequal(Symbol(sym)), p_names) == 1
i = findfirst(isequal(Symbol(sym)), p_names)
A.p[i] = val
return A
end
end
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
has_sys(A.f) || error("Invalid indexing of integrator: Integrator does not support indexing without a system")
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)
else
error("Invalid indexing of integrator: $sym is not a symbol")
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
return A
elseif symbolic_type(sym) == ArraySymbolic()
setindex!.((A,), val, collect(sym))
return A
else
error("Invalid indexing of integrator: Integrator does not support indexing without a system")
sym isa AbstractArray || error("Invalid indexing of integrator")
setindex!.((A,), val, sym)
return A
end
end

Expand Down
Loading

0 comments on commit dc1c83d

Please sign in to comment.