Skip to content

Commit

Permalink
Merge pull request #2486 from AayushSabharwal/as/symbolic-save-idxs
Browse files Browse the repository at this point in the history
feat: support passing symbolic variables to `save_idxs`
  • Loading branch information
ChrisRackauckas authored Nov 4, 2024
2 parents c87f6f2 + 9e45804 commit 1d05b16
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
4 changes: 3 additions & 1 deletion lib/OrdinaryDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
Expand Down Expand Up @@ -69,13 +70,14 @@ Random = "<0.0.1, 1"
RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SafeTestsets = "0.1.0"
SciMLBase = "2.56"
SciMLBase = "2.57.2"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleUnPack = "1"
Static = "0.8, 1"
StaticArrayInterface = "1.2"
StaticArraysCore = "1.0"
SymbolicIndexingInterface = "0.3.31"
Test = "<0.0.1, 1"
TruncatedStacktraces = "1.2"
julia = "1.10"
Expand Down
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ import Accessors: @reset

using SciMLStructures: canonicalize, Tunable, isscimlstructure

using SymbolicIndexingInterface: parameter_values, is_variable, variable_index, symbolic_type, NotSymbolic

const CompiledFloats = Union{Float32, Float64}
import Preferences

Expand Down
26 changes: 25 additions & 1 deletion lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,30 @@ function DiffEqBase.__init(
end

### Algorithm-specific defaults ###
if save_idxs === nothing
saved_subsystem = nothing
else
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
_save_idxs = [save_idxs]
else
_save_idxs = save_idxs
end
saved_subsystem = SciMLBase.SavedSubsystem(prob, parameter_values(prob), _save_idxs)
if saved_subsystem !== nothing
_save_idxs = SciMLBase.get_saved_state_idxs(saved_subsystem)
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
else
save_idxs = _save_idxs
end
end
end

if save_idxs === nothing
ksEltype = Vector{rateType}
else
Expand Down Expand Up @@ -427,7 +451,7 @@ function DiffEqBase.__init(
f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false)
sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries,
dense = dense, k = ks, interp = id, alg_choice = alg_choice,
calculate_error = false, stats = stats)
calculate_error = false, stats = stats, saved_subsystem = saved_subsystem)

if recompile_flag == true
FType = typeof(f)
Expand Down

0 comments on commit 1d05b16

Please sign in to comment.