Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cache stripping to solution stripping #2393

Merged
merged 12 commits into from
Aug 22, 2024
26 changes: 26 additions & 0 deletions lib/OrdinaryDiffEqCore/src/interp_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,29 @@ function InterpolationData(id::InterpolationData, f)
id.differential_vars,
id.sensitivitymode)
end

# strip interpolation of function information
function SciMLBase.strip_interpolation(id::InterpolationData)

cache = strip_cache(id.cache)

InterpolationData(nothing, id.timeseries,
id.ts,
id.ks,
id.alg_choice,
id.dense,
cache,
id.differential_vars,
id.sensitivitymode)
end

function strip_cache(cache)
if hasfield(typeof(cache), :jac_config) || hasfield(typeof(cache), :grad_config)
fieldnums = length(fieldnames(typeof(cache)))
noth_list = fill(nothing,fieldnums)
cache_type_name = Base.typename(typeof(cache)).wrapper
cache_type_name(noth_list...)
else
cache
end
end
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ import OrdinaryDiffEqCore: trivial_limiter!, CompositeAlgorithm, alg_order,
_change_t_via_interpolation!, ODEIntegrator, _ode_interpolant!,
current_interpolant, resize_nlsolver!, _ode_interpolant,
handle_tstop!, _postamble!, update_uprev!, resize_J_W!,
DAEAlgorithm, get_fsalfirstlast
DAEAlgorithm, get_fsalfirstlast, strip_cache, strip_interpolation

export CompositeAlgorithm, ShampineCollocationInit, BrownFullBasicInit, NoInit
AutoSwitch
Expand Down
17 changes: 17 additions & 0 deletions test/interface/ode_strip_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using OrdinaryDiffEq, Test
import SciMLBase

function lorenz!(du, u, p, t)
du[1] = 10.0 * (u[2] - u[1])
du[2] = u[1] * (28.0 - u[3]) - u[2]
du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 0.5)
prob = ODEProblem(lorenz!, u0, tspan)

sol = solve(prob, Rosenbrock23())

@test isnothing(SciMLBase.strip_interpolation(sol.interp).f)
@test isnothing(SciMLBase.strip_interpolation(sol.interp).cache.jac_config)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ end
@time @safetestset "Inplace Interpolation Tests" include("interface/inplace_interpolation.jl")
@time @safetestset "Algebraic Interpolation Tests" include("interface/algebraic_interpolation.jl")
@time @safetestset "Default Solver Tests" include("interface/default_solver_tests.jl")
@time @safetestset "Interpolation and Cache Stripping Tests" include("interface/ode_strip_test.jl")
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceII" || GROUP == "Interface")
Expand Down
Loading