Skip to content

Commit

Permalink
feat: support passing symbolic variables to save_idxs
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 7, 2024
1 parent 457611c commit 1103634
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
Expand Down Expand Up @@ -72,6 +73,7 @@ SimpleUnPack = "1"
Static = "0.8, 1"
StaticArrayInterface = "1.2"
StaticArraysCore = "1.0"
SymbolicIndexingInterface = "0.3.31"
Test = "<0.0.1, 1"
TruncatedStacktraces = "1.2"
julia = "1.10"
Expand Down
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ import ADTypes: AutoFiniteDiff, AutoForwardDiff

using SciMLStructures: canonicalize, Tunable, isscimlstructure

using SymbolicIndexingInterface: parameter_values, is_variable, variable_index, symbolic_type, NotSymbolic

const CompiledFloats = Union{Float32, Float64}
import Preferences

Expand Down
24 changes: 23 additions & 1 deletion lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,28 @@ function DiffEqBase.__init(
end

### Algorithm-specific defaults ###
if save_idxs === nothing
saved_subsystem = nothing
else
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
_save_idxs = [save_idxs]
else
_save_idxs = save_idxs
end
saved_subsystem = SciMLBase.SavedSubsystem(prob, parameter_values(prob), _save_idxs)
_save_idxs = map(Base.Fix1(variable_index, prob), filter(Base.Fix1(is_variable, prob), _save_idxs))
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
else
save_idxs = _save_idxs
end
end

if save_idxs === nothing
ksEltype = Vector{rateType}
else
Expand Down Expand Up @@ -424,7 +446,7 @@ function DiffEqBase.__init(
f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false)
sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries,
dense = dense, k = ks, interp = id, alg_choice = alg_choice,
calculate_error = false, stats = stats)
calculate_error = false, stats = stats, saved_subsystem = saved_subsystem)

if recompile_flag == true
FType = typeof(f)
Expand Down

0 comments on commit 1103634

Please sign in to comment.