Skip to content

Commit

Permalink
feat: update RecursiveArrayTools and SymbolicIndexingInterface compat
Browse files Browse the repository at this point in the history
- Implement SII for solutions, problems, SciMLFunctions, integrators
- Solution indexing falls back to AbstractVectorOfArray indexing
- remove issymbollike
- remove has_static_variable
- add new parameter indexing
- fix adjoint definitions
- Remove syms, paramsyms, indepsym from SciMLFunctions, use SymbolCache in place of
  f.sys instead
  • Loading branch information
AayushSabharwal committed Dec 12, 2023
1 parent 9fdb9c8 commit bbba9c6
Show file tree
Hide file tree
Showing 18 changed files with 625 additions and 891 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.13.18"
RecipesBase = "1.0"
RecursiveArrayTools = "2.38"
RecursiveArrayTools = "3.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5"
SciMLOperators = "0.3.7"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.9"
SymbolicIndexingInterface = "0.2"
SymbolicIndexingInterface = "0.3"
Tables = "1.11"
TruncatedStacktraces = "1.4"
QuasiMonteCarlo = "0.2.19, 0.3"
Expand Down
4 changes: 2 additions & 2 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{
sym,
j::Integer)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym

Check warning on line 15 in ext/SciMLBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseChainRulesCoreExt.jl#L15

Added line #L15 was not covered by tests
if i === nothing
getter = getobserved(VA)
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -65,7 +65,7 @@ end

function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym

Check warning on line 68 in ext/SciMLBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseChainRulesCoreExt.jl#L68

Added line #L68 was not covered by tests
if i === nothing
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
9 changes: 5 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ using Zygote
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
using SciMLBase: ODESolution, sym_to_index, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand All @@ -32,7 +33,7 @@ end

@adjoint function getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym

Check warning on line 36 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L36

Added line #L36 was not covered by tests
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -81,7 +82,7 @@ end
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
VA[i], ODESolution_getindex_pullback
VA[:, i], ODESolution_getindex_pullback

Check warning on line 85 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L85

Added line #L85 was not covered by tests
end

@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
Expand All @@ -91,7 +92,7 @@ end

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym

Check warning on line 95 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L95

Added line #L95 was not covered by tests
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
20 changes: 5 additions & 15 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ end # Vector of some type which is not an array
function EnsembleSolution(sim::T, elapsedTime,
converged, stats=nothing) where {T <: AbstractVector{T2}
} where {T2 <:
AbstractArray}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1, typeof(sim)}(
sim,
Union{AbstractArray,RecursiveArrayTools.AbstractVectorOfArray}}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1,

Check warning on line 44 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L44

Added line #L44 was not covered by tests
typeof(sim)}(sim,
elapsedTime,
converged,
stats)
Expand Down Expand Up @@ -209,18 +209,8 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,
::Colon,
args::Colon...)
return invoke(getindex,
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
return [xi[s] for xi in x.u]

Check warning on line 213 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L212-L213

Added lines #L212 - L213 were not covered by tests
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
184 changes: 87 additions & 97 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,13 @@ function set_u!(integrator::DEIntegrator, sym, val)
# So any error checking happens to ensure we actually _can_ set state
set_u!(integrator, integrator.u)

if !issymbollike(sym)
if symbolic_type(sym) == NotSymbolic()

Check warning on line 348 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L348

Added line #L348 was not covered by tests
error("sym must be a symbol")
end
i = sym_to_index(sym, integrator)
i = variable_index(integrator, sym)

Check warning on line 351 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L351

Added line #L351 was not covered by tests

if isnothing(i)
error("sym is not a state variable")
error("$sym is not a state variable")

Check warning on line 354 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L354

Added line #L354 was not covered by tests
end

integrator.u[i] = val
Expand Down Expand Up @@ -385,27 +385,27 @@ end

### Indexing
function getsyms(integrator::DEIntegrator)
if has_syms(integrator.f)
return integrator.f.syms
else
return keys(integrator.u[1])
syms = variable_symbols(integrator)
if isempty(syms)
syms = keys(integrator.u)

Check warning on line 390 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L388-L390

Added lines #L388 - L390 were not covered by tests
end
return syms

Check warning on line 392 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L392

Added line #L392 was not covered by tests
end

function getindepsym(integrator::DEIntegrator)
if has_indepsym(integrator.f)
return integrator.f.indepsym
else
syms = independent_variable_symbols(integrator)
if isempty(syms)

Check warning on line 397 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L396-L397

Added lines #L396 - L397 were not covered by tests
return nothing
end
return syms

Check warning on line 400 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L400

Added line #L400 was not covered by tests
end

function getparamsyms(integrator::DEIntegrator)
if has_paramsyms(integrator.f)
return integrator.f.paramsyms
else
psyms = parameter_symbols(integrator)
if isempty(psyms)

Check warning on line 405 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L404-L405

Added lines #L404 - L405 were not covered by tests
return nothing
end
return psyms

Check warning on line 408 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L408

Added line #L408 was not covered by tests
end

function getobserved(integrator::DEIntegrator)
Expand All @@ -417,58 +417,76 @@ function getobserved(integrator::DEIntegrator)
end

function sym_to_index(sym, integrator::DEIntegrator)
if has_sys(integrator.f) && is_state_sym(integrator.f.sys, sym)
return state_sym_to_index(integrator.f.sys, sym)
idx = variable_index(integrator, sym)
if idx === nothing
idx = findfirst(isequal(sym), keys(integrator.u))

Check warning on line 422 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L420-L422

Added lines #L420 - L422 were not covered by tests
end
return idx

Check warning on line 424 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L424

Added line #L424 was not covered by tests
end

# SymbolicIndexingInterface
SymbolicIndexingInterface.symbolic_container(A::DEIntegrator) = A.f
SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p

Check warning on line 429 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L428-L429

Added lines #L428 - L429 were not covered by tests

function SymbolicIndexingInterface.is_observed(A::DEIntegrator, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()

Check warning on line 432 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L431-L432

Added lines #L431 - L432 were not covered by tests
end

function SymbolicIndexingInterface.observed(A::DEIntegrator, sym)
(u, p, t) -> getobserved(A)(sym, u, p, t)

Check warning on line 436 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L435-L436

Added lines #L435 - L436 were not covered by tests
end

SymbolicIndexingInterface.is_time_dependent(::DEIntegrator) = true

Check warning on line 439 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L439

Added line #L439 was not covered by tests

# TODO make this nontrivial once dynamic state selection works
SymbolicIndexingInterface.constant_structure(::DEIntegrator) = true

Check warning on line 442 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L442

Added line #L442 was not covered by tests

function Base.getproperty(A::DEIntegrator, sym::Symbol)
if sym === :destats && hasfield(typeof(A), :stats)
@warn "destats has been deprecated for stats"
getfield(A, :stats)

Check warning on line 447 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L444-L447

Added lines #L444 - L447 were not covered by tests
else
return sym_to_index(sym, getsyms(integrator))
return getfield(A, sym)

Check warning on line 449 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L449

Added line #L449 was not covered by tests
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator,
I::Union{Int, AbstractArray{Int},
Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},

Check warning on line 453 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L453

Added line #L453 was not covered by tests
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
RecursiveArrayTools.VectorOfArray(A.u)[I...]
A.u[I...]

Check warning on line 456 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L456

Added line #L456 was not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.p, A.t)

Check warning on line 468 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L459-L468

Added lines #L459 - L468 were not covered by tests
else
error("Tried to index integrator with a Symbol that was not found in the system.")

Check warning on line 470 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L470

Added line #L470 was not covered by tests
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym)
return A[collect(sym)]

Check warning on line 475 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L474-L475

Added lines #L474 - L475 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
return getindex.((A,), sym)

Check warning on line 479 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L478-L479

Added lines #L478 - L479 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
if issymbollike(sym)
if sym isa AbstractArray
return A[collect(sym)]
end
i = sym_to_index(sym, A)
elseif all(issymbollike, sym)
return getindex.((A,), sym)
else
i = sym
end
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

Check warning on line 484 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L483-L484

Added lines #L483 - L484 were not covered by tests

if i === nothing
if issymbollike(sym)
if has_sys(A.f) && is_indep_sym(A.f.sys, sym) ||
Symbol(sym) == getindepsym(A)
return A.t
elseif has_sys(A.f) && is_param_sym(A.f.sys, sym)
return A.p[param_sym_to_index(A.f.sys, sym)]
elseif has_paramsyms(A.f) && Symbol(sym) in getparamsyms(A)
return A.p[findfirst(x -> isequal(x, Symbol(sym)), getparamsyms(A))]
elseif (sym isa Symbol) && has_sys(A.f) && hasproperty(A.f.sys, sym) # Handles input like :X (where X is a state).
return observed(A, getproperty(A.f.sys, sym))
elseif has_sys(A.f) && (count('', String(Symbol(sym))) == 1) &&
(count(isequal(Symbol(sym)),
Symbol.(A.f.sys.name, :₊, getparamsyms(A))) == 1) # Handles input like sys.X (where X is a parameter).
return A.p[findfirst(isequal(Symbol(sym)),
Symbol.(A.f.sys.name, :₊, getparamsyms(A)))]
else
return observed(A, sym)
end
else
observed(A, sym)
end
elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer}
A[i]
if symtype != NotSymbolic()
return getindex(A, symtype, sym)

Check warning on line 487 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L486-L487

Added lines #L486 - L487 were not covered by tests
else
error("Invalid indexing of integrator")
return getindex(A, elsymtype, sym)

Check warning on line 489 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L489

Added line #L489 was not covered by tests
end
end

Expand All @@ -477,52 +495,24 @@ function observed(A::DEIntegrator, sym)
end

function Base.setindex!(A::DEIntegrator, val, sym)
if has_sys(A.f)
if issymbollike(sym)
params = getparamsyms(A)
s = Symbol.(states(A.f.sys))
params = Symbol.(params)

i = findfirst(isequal(Symbol(sym)), s)
if !isnothing(i)
A.u[i] = val
return A
elseif sym isa Symbol # Handles input like :X.
s_f = Symbol.(getproperty.(states(A.f.sys), :f))
if count(isequal(Symbol(sym)), s_f) == 1
i = findfirst(isequal(sym), s_f)
A.u[i] = val
return A
elseif count(isequal(Symbol(sym)), s_f) > 1
error("The input symbol $(sym) occurs several times among integrator states. Please avoid use Symbol form (:$(sym)).")
end
elseif count('', String(Symbol(sym))) == 1 # Handles input like sys.X.
s_names = Symbol.(A.f.sys.name, :₊, s)
if count(isequal(Symbol(sym)), s_names) == 1
i = findfirst(isequal(Symbol(sym)), s_names)
A.u[i] = val
return A
end
end

i = findfirst(isequal(Symbol(sym)), params)
if !isnothing(i)
A.p[i] = val
return A
elseif count('', String(Symbol(sym))) == 1 # Handles input like sys.X.
p_names = Symbol.(A.f.sys.name, :₊, params)
if count(isequal(Symbol(sym)), p_names) == 1
i = findfirst(isequal(Symbol(sym)), p_names)
A.p[i] = val
return A
end
end
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
has_sys(A.f) || error("Invalid indexing of integrator: Integrator does not support indexing without a system")
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)

Check warning on line 504 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L498-L504

Added lines #L498 - L504 were not covered by tests
else
error("Invalid indexing of integrator: $sym is not a symbol")
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")

Check warning on line 506 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L506

Added line #L506 was not covered by tests
end
return A
elseif symbolic_type(sym) == ArraySymbolic()
setindex!.((A,), val, collect(sym))
return A

Check warning on line 511 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L508-L511

Added lines #L508 - L511 were not covered by tests
else
error("Invalid indexing of integrator: Integrator does not support indexing without a system")
sym isa AbstractArray || error("Invalid indexing of integrator")
setindex!.((A,), val, sym)
return A

Check warning on line 515 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L513-L515

Added lines #L513 - L515 were not covered by tests
end
end

Expand Down
Loading

0 comments on commit bbba9c6

Please sign in to comment.