Skip to content

Commit

Permalink
Handle save_idxs changing sensitivity solution size
Browse files Browse the repository at this point in the history
Captured downstream
  • Loading branch information
ChrisRackauckas committed Jan 6, 2024
1 parent f96ed6b commit 144e210
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
16 changes: 12 additions & 4 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function Base.merge(a::DEStats, b::DEStats)
a.naccept + b.naccept,
a.nreject + b.nreject,
max(a.maxeig, b.maxeig),
)
)
end

"""
Expand All @@ -95,8 +95,8 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
- `stats`: statistics of the solver, such as the number of function evaluations required,
number of Jacobians computed, and more.
- `retcode`: the return code from the solver. Used to determine whether the solver solved
successfully, whether it terminated early due to a user-defined callback, or whether it
exited due to an error. For more details, see
successfully, whether it terminated early due to a user-defined callback, or whether it
exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
"""
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S,
Expand Down Expand Up @@ -387,7 +387,15 @@ end

function sensitivity_solution(sol::ODESolution, u, t)
T = eltype(eltype(u))
N = length((size(sol.prob.u0)..., length(u)))

# handle save_idxs
u0 = first(u)
if u0 isa Number
N = 1
else
N = length((size(prob.u0)..., length(u)))
end

interp = if sol.interp isa LinearInterpolation
LinearInterpolation(t, u)
elseif sol.interp isa ConstantInterpolation
Expand Down
10 changes: 9 additions & 1 deletion src/solutions/rode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,15 @@ end

function sensitivity_solution(sol::AbstractRODESolution, u, t)
T = eltype(eltype(u))
N = length((size(sol.prob.u0)..., length(u)))

# handle save_idxs
u0 = first(u)
if u0 isa Number
N = 1
else
N = length((size(prob.u0)..., length(u)))
end

interp = if sol.interp isa LinearInterpolation
LinearInterpolation(t, u)
elseif sol.interp isa ConstantInterpolation
Expand Down

0 comments on commit 144e210

Please sign in to comment.