Skip to content

Commit

Permalink
Merge pull request #2393 from jClugstor/cache_stripping
Browse files Browse the repository at this point in the history
Add cache stripping to solution stripping
  • Loading branch information
ChrisRackauckas authored Aug 22, 2024
2 parents 83c55c9 + c1e3ed1 commit 05f9229
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 1 deletion.
26 changes: 26 additions & 0 deletions lib/OrdinaryDiffEqCore/src/interp_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,29 @@ function InterpolationData(id::InterpolationData, f)
id.differential_vars,
id.sensitivitymode)
end

# strip interpolation of function information
function SciMLBase.strip_interpolation(id::InterpolationData)

cache = strip_cache(id.cache)

InterpolationData(nothing, id.timeseries,
id.ts,
id.ks,
id.alg_choice,
id.dense,
cache,
id.differential_vars,
id.sensitivitymode)
end

function strip_cache(cache)
if hasfield(typeof(cache), :jac_config) || hasfield(typeof(cache), :grad_config)
fieldnums = length(fieldnames(typeof(cache)))
noth_list = fill(nothing,fieldnums)
cache_type_name = Base.typename(typeof(cache)).wrapper
cache_type_name(noth_list...)
else
cache
end
end
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ import OrdinaryDiffEqCore: trivial_limiter!, CompositeAlgorithm, alg_order,
_change_t_via_interpolation!, ODEIntegrator, _ode_interpolant!,
current_interpolant, resize_nlsolver!, _ode_interpolant,
handle_tstop!, _postamble!, update_uprev!, resize_J_W!,
DAEAlgorithm, get_fsalfirstlast
DAEAlgorithm, get_fsalfirstlast, strip_cache, strip_interpolation

export CompositeAlgorithm, ShampineCollocationInit, BrownFullBasicInit, NoInit
AutoSwitch
Expand Down
17 changes: 17 additions & 0 deletions test/interface/ode_strip_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using OrdinaryDiffEq, Test
import SciMLBase

function lorenz!(du, u, p, t)
du[1] = 10.0 * (u[2] - u[1])
du[2] = u[1] * (28.0 - u[3]) - u[2]
du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 0.5)
prob = ODEProblem(lorenz!, u0, tspan)

sol = solve(prob, Rosenbrock23())

@test isnothing(SciMLBase.strip_interpolation(sol.interp).f)
@test isnothing(SciMLBase.strip_interpolation(sol.interp).cache.jac_config)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ end
@time @safetestset "Inplace Interpolation Tests" include("interface/inplace_interpolation.jl")
@time @safetestset "Algebraic Interpolation Tests" include("interface/algebraic_interpolation.jl")
@time @safetestset "Default Solver Tests" include("interface/default_solver_tests.jl")
@time @safetestset "Interpolation and Cache Stripping Tests" include("interface/ode_strip_test.jl")
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceII" || GROUP == "Interface")
Expand Down

0 comments on commit 05f9229

Please sign in to comment.