diff --git a/lib/OrdinaryDiffEqCore/Project.toml b/lib/OrdinaryDiffEqCore/Project.toml index 99b5cc8b23..decba70d05 100644 --- a/lib/OrdinaryDiffEqCore/Project.toml +++ b/lib/OrdinaryDiffEqCore/Project.toml @@ -34,6 +34,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [weakdeps] @@ -69,13 +70,14 @@ Random = "<0.0.1, 1" RecursiveArrayTools = "2.36, 3" Reexport = "1.0" SafeTestsets = "0.1.0" -SciMLBase = "2.56" +SciMLBase = "2.57.2" SciMLOperators = "0.3" SciMLStructures = "1" SimpleUnPack = "1" Static = "0.8, 1" StaticArrayInterface = "1.2" StaticArraysCore = "1.0" +SymbolicIndexingInterface = "0.3.31" Test = "<0.0.1, 1" TruncatedStacktraces = "1.2" julia = "1.10" diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index 6376270182..a7cbd95167 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -76,6 +76,8 @@ import Accessors: @reset using SciMLStructures: canonicalize, Tunable, isscimlstructure +using SymbolicIndexingInterface: parameter_values, is_variable, variable_index, symbolic_type, NotSymbolic + const CompiledFloats = Union{Float32, Float64} import Preferences diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index e65a766b11..ef77e4c6b4 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -264,6 +264,30 @@ function DiffEqBase.__init( end ### Algorithm-specific defaults ### + if save_idxs === nothing + saved_subsystem = nothing + else + if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() + _save_idxs = [save_idxs] + else + _save_idxs = save_idxs + end + saved_subsystem = SciMLBase.SavedSubsystem(prob, parameter_values(prob), _save_idxs) + if saved_subsystem !== nothing + _save_idxs = SciMLBase.get_saved_state_idxs(saved_subsystem) + if isempty(_save_idxs) + # no states to save + save_idxs = Int[] + elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() + # only a single state to save, and save it as a scalar timeseries instead of + # single-element array + save_idxs = only(_save_idxs) + else + save_idxs = _save_idxs + end + end + end + if save_idxs === nothing ksEltype = Vector{rateType} else @@ -427,7 +451,7 @@ function DiffEqBase.__init( f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false) sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, dense = dense, k = ks, interp = id, alg_choice = alg_choice, - calculate_error = false, stats = stats) + calculate_error = false, stats = stats, saved_subsystem = saved_subsystem) if recompile_flag == true FType = typeof(f)