Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: indexing rework with new SymbolicIndexingInterface #532

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.13.18"
RecipesBase = "1.0"
RecursiveArrayTools = "2.38"
RecursiveArrayTools = "3.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5"
SciMLOperators = "0.3.7"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.9"
SymbolicIndexingInterface = "0.2"
SymbolicIndexingInterface = "0.3"
Tables = "1.11"
TruncatedStacktraces = "1.4"
QuasiMonteCarlo = "0.2.19, 0.3"
Expand Down
4 changes: 2 additions & 2 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{
sym,
j::Integer)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
if i === nothing
getter = getobserved(VA)
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -65,7 +65,7 @@ end

function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
if i === nothing
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
9 changes: 5 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ using Zygote
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
using SciMLBase: ODESolution, sym_to_index, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand All @@ -32,7 +33,7 @@ end

@adjoint function getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -81,7 +82,7 @@ end
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
VA[i], ODESolution_getindex_pullback
VA[:, i], ODESolution_getindex_pullback
end

@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
Expand All @@ -91,7 +92,7 @@ end

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
20 changes: 5 additions & 15 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ end # Vector of some type which is not an array
function EnsembleSolution(sim::T, elapsedTime,
converged, stats=nothing) where {T <: AbstractVector{T2}
} where {T2 <:
AbstractArray}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1, typeof(sim)}(
sim,
Union{AbstractArray,RecursiveArrayTools.AbstractVectorOfArray}}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1,
typeof(sim)}(sim,
elapsedTime,
converged,
stats)
Expand Down Expand Up @@ -209,18 +209,8 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am very happy you were able to get rid of this.

::Colon,
args::Colon...)
return invoke(getindex,
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
return [xi[s] for xi in x.u]
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
184 changes: 87 additions & 97 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,13 @@ function set_u!(integrator::DEIntegrator, sym, val)
# So any error checking happens to ensure we actually _can_ set state
set_u!(integrator, integrator.u)

if !issymbollike(sym)
if symbolic_type(sym) == NotSymbolic()
error("sym must be a symbol")
end
i = sym_to_index(sym, integrator)
i = variable_index(integrator, sym)

if isnothing(i)
error("sym is not a state variable")
error("$sym is not a state variable")
end

integrator.u[i] = val
Expand Down Expand Up @@ -385,27 +385,27 @@ end

### Indexing
function getsyms(integrator::DEIntegrator)
if has_syms(integrator.f)
return integrator.f.syms
else
return keys(integrator.u[1])
syms = variable_symbols(integrator)
if isempty(syms)
syms = keys(integrator.u)
end
return syms
end

function getindepsym(integrator::DEIntegrator)
if has_indepsym(integrator.f)
return integrator.f.indepsym
else
syms = independent_variable_symbols(integrator)
if isempty(syms)
return nothing
end
return syms
end

function getparamsyms(integrator::DEIntegrator)
if has_paramsyms(integrator.f)
return integrator.f.paramsyms
else
psyms = parameter_symbols(integrator)
if isempty(psyms)
return nothing
end
return psyms
end

function getobserved(integrator::DEIntegrator)
Expand All @@ -417,58 +417,76 @@ function getobserved(integrator::DEIntegrator)
end

function sym_to_index(sym, integrator::DEIntegrator)
if has_sys(integrator.f) && is_state_sym(integrator.f.sys, sym)
return state_sym_to_index(integrator.f.sys, sym)
idx = variable_index(integrator, sym)
if idx === nothing
idx = findfirst(isequal(sym), keys(integrator.u))
end
return idx
end

# SymbolicIndexingInterface
SymbolicIndexingInterface.symbolic_container(A::DEIntegrator) = A.f
SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p

function SymbolicIndexingInterface.is_observed(A::DEIntegrator, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()
end

function SymbolicIndexingInterface.observed(A::DEIntegrator, sym)
(u, p, t) -> getobserved(A)(sym, u, p, t)
end

SymbolicIndexingInterface.is_time_dependent(::DEIntegrator) = true

# TODO make this nontrivial once dynamic state selection works
SymbolicIndexingInterface.constant_structure(::DEIntegrator) = true

function Base.getproperty(A::DEIntegrator, sym::Symbol)
if sym === :destats && hasfield(typeof(A), :stats)
@warn "destats has been deprecated for stats"
getfield(A, :stats)
else
return sym_to_index(sym, getsyms(integrator))
return getfield(A, sym)
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator,
I::Union{Int, AbstractArray{Int},
Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
RecursiveArrayTools.VectorOfArray(A.u)[I...]
A.u[I...]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.p, A.t)
else
error("Tried to index integrator with a Symbol that was not found in the system.")
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym)
return A[collect(sym)]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
return getindex.((A,), sym)
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
if issymbollike(sym)
if sym isa AbstractArray
return A[collect(sym)]
end
i = sym_to_index(sym, A)
elseif all(issymbollike, sym)
return getindex.((A,), sym)
else
i = sym
end
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

if i === nothing
if issymbollike(sym)
if has_sys(A.f) && is_indep_sym(A.f.sys, sym) ||
Symbol(sym) == getindepsym(A)
return A.t
elseif has_sys(A.f) && is_param_sym(A.f.sys, sym)
return A.p[param_sym_to_index(A.f.sys, sym)]
elseif has_paramsyms(A.f) && Symbol(sym) in getparamsyms(A)
return A.p[findfirst(x -> isequal(x, Symbol(sym)), getparamsyms(A))]
elseif (sym isa Symbol) && has_sys(A.f) && hasproperty(A.f.sys, sym) # Handles input like :X (where X is a state).
return observed(A, getproperty(A.f.sys, sym))
elseif has_sys(A.f) && (count('₊', String(Symbol(sym))) == 1) &&
(count(isequal(Symbol(sym)),
Symbol.(A.f.sys.name, :₊, getparamsyms(A))) == 1) # Handles input like sys.X (where X is a parameter).
return A.p[findfirst(isequal(Symbol(sym)),
Symbol.(A.f.sys.name, :₊, getparamsyms(A)))]
else
return observed(A, sym)
end
else
observed(A, sym)
end
elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer}
A[i]
if symtype != NotSymbolic()
return getindex(A, symtype, sym)
else
error("Invalid indexing of integrator")
return getindex(A, elsymtype, sym)
end
end

Expand All @@ -477,52 +495,24 @@ function observed(A::DEIntegrator, sym)
end

function Base.setindex!(A::DEIntegrator, val, sym)
if has_sys(A.f)
if issymbollike(sym)
params = getparamsyms(A)
s = Symbol.(states(A.f.sys))
params = Symbol.(params)

i = findfirst(isequal(Symbol(sym)), s)
if !isnothing(i)
A.u[i] = val
return A
elseif sym isa Symbol # Handles input like :X.
s_f = Symbol.(getproperty.(states(A.f.sys), :f))
if count(isequal(Symbol(sym)), s_f) == 1
i = findfirst(isequal(sym), s_f)
A.u[i] = val
return A
elseif count(isequal(Symbol(sym)), s_f) > 1
error("The input symbol $(sym) occurs several times among integrator states. Please avoid use Symbol form (:$(sym)).")
end
elseif count('₊', String(Symbol(sym))) == 1 # Handles input like sys.X.
s_names = Symbol.(A.f.sys.name, :₊, s)
if count(isequal(Symbol(sym)), s_names) == 1
i = findfirst(isequal(Symbol(sym)), s_names)
A.u[i] = val
return A
end
end

i = findfirst(isequal(Symbol(sym)), params)
if !isnothing(i)
A.p[i] = val
return A
elseif count('₊', String(Symbol(sym))) == 1 # Handles input like sys.X.
p_names = Symbol.(A.f.sys.name, :₊, params)
if count(isequal(Symbol(sym)), p_names) == 1
i = findfirst(isequal(Symbol(sym)), p_names)
A.p[i] = val
return A
end
end
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
has_sys(A.f) || error("Invalid indexing of integrator: Integrator does not support indexing without a system")
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)
else
error("Invalid indexing of integrator: $sym is not a symbol")
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
return A
elseif symbolic_type(sym) == ArraySymbolic()
setindex!.((A,), val, collect(sym))
return A
else
error("Invalid indexing of integrator: Integrator does not support indexing without a system")
sym isa AbstractArray || error("Invalid indexing of integrator")
setindex!.((A,), val, sym)
return A
end
end

Expand Down
Loading