diff --git a/lib/OrdinaryDiffEqCore/Project.toml b/lib/OrdinaryDiffEqCore/Project.toml index 68579225a2..5f29c546a8 100644 --- a/lib/OrdinaryDiffEqCore/Project.toml +++ b/lib/OrdinaryDiffEqCore/Project.toml @@ -1,7 +1,7 @@ name = "OrdinaryDiffEqCore" uuid = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" authors = ["ParamThakkar123 "] -version = "1.7.0" +version = "1.7.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/OrdinaryDiffEqCore/src/interp_func.jl b/lib/OrdinaryDiffEqCore/src/interp_func.jl index 36debb6883..dfa3dee146 100644 --- a/lib/OrdinaryDiffEqCore/src/interp_func.jl +++ b/lib/OrdinaryDiffEqCore/src/interp_func.jl @@ -75,24 +75,14 @@ function SciMLBase.strip_interpolation(id::InterpolationData) end function strip_cache(cache) - if hasfield(typeof(cache), :jac_config) - SciMLBase.@reset cache.jac_config = nothing + if !(cache isa OrdinaryDiffEqCore.DefaultCache) + cache = SciMLBase.constructorof(typeof(cache))([nothing + for name in fieldnames(typeof(cache))]...) + else + # need to do something special for default cache + cache = OrdinaryDiffEqCore.DefaultCache{Nothing, Nothing, Nothing, Nothing, + Nothing, Nothing, Nothing, Nothing}(nothing, nothing, 0, nothing) end - if hasfield(typeof(cache), :grad_config) - SciMLBase.@reset cache.grad_config = nothing - end - if hasfield(typeof(cache), :nlsolver) - SciMLBase.@reset cache.nlsolver = nothing - end - if hasfield(typeof(cache), :tf) - SciMLBase.@reset cache.tf = nothing - end - if hasfield(typeof(cache), :uf) - SciMLBase.@reset cache.uf = nothing - end - if hasfield(typeof(cache),:args) - SciMLBase.@reset cache.args = nothing - end - + cache end diff --git a/test/interface/ode_strip_test.jl b/test/interface/ode_strip_test.jl index ad5a794bd8..f30a4492aa 100644 --- a/test/interface/ode_strip_test.jl +++ b/test/interface/ode_strip_test.jl @@ -14,7 +14,7 @@ prob = ODEProblem(lorenz!, u0, tspan) rosenbrock_sol = solve(prob, Rosenbrock23()) TRBDF_sol = solve(prob, TRBDF2()) vern_sol = solve(prob, Vern7()) - +default_sol = solve(prob) @testset "Interpolation Stripping" begin @test isnothing(SciMLBase.strip_interpolation(rosenbrock_sol.interp).f) @test isnothing(SciMLBase.strip_interpolation(rosenbrock_sol.interp).cache.jac_config) @@ -22,20 +22,28 @@ vern_sol = solve(prob, Vern7()) end @testset "Rosenbrock Solution Stripping" begin - @test SciMLBase.strip_solution(rosenbrock_sol).prob isa NamedTuple + stripped_sol = SciMLBase.strip_solution(rosenbrock_sol) + @test stripped_sol.prob isa NamedTuple @test isnothing(SciMLBase.strip_solution(rosenbrock_sol, strip_alg = true).alg) - @test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.f) - @test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.jac_config) - @test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.grad_config) - @test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.uf) - @test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.tf) + @test isnothing(stripped_sol.interp.f) + @test isnothing(stripped_sol.interp.cache.jac_config) + @test isnothing(stripped_sol.interp.cache.grad_config) + @test isnothing(stripped_sol.interp.cache.uf) + @test isnothing(stripped_sol.interp.cache.tf) end @testset "TRBDF Solution Stripping" begin - @test SciMLBase.strip_solution(TRBDF_sol).prob isa NamedTuple + stripped_sol = SciMLBase.strip_solution(TRBDF_sol) + @test stripped_sol.prob isa NamedTuple @test isnothing(SciMLBase.strip_solution(TRBDF_sol, strip_alg = true).alg) - @test isnothing(SciMLBase.strip_solution(TRBDF_sol).interp.f) - @test isnothing(SciMLBase.strip_solution(TRBDF_sol).interp.cache.nlsolver) + @test isnothing(stripped_sol.interp.f) + @test isnothing(stripped_sol.interp.cache.nlsolver) +end + +@testset "Default Solution Stripping" begin + stripped_sol = SciMLBase.strip_solution(default_sol) + @test isnothing(stripped_sol.interp.cache.args) + end @test_throws SciMLBase.LazyInterpolationException SciMLBase.strip_solution(vern_sol)