diff --git a/Project.toml b/Project.toml index 42b9806..4e6a8a4 100644 --- a/Project.toml +++ b/Project.toml @@ -35,8 +35,8 @@ IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -SparsityTracing = "06eadbd4-12ad-4cbc-ab6e-10f8370940a5" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "AlgebraicMultigrid", "DiffEqCallbacks", "ODEProblemLibrary", "DAEProblemLibrary", "ForwardDiff", "SparseDiffTools", "SparsityTracing", "IncompleteLU", "ModelingToolkit"] +test = ["Test", "AlgebraicMultigrid", "DiffEqCallbacks", "ODEProblemLibrary", "DAEProblemLibrary", "ForwardDiff", "SparseDiffTools", "SparseConnectivityTracer", "IncompleteLU", "ModelingToolkit"] diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index 004fbdd..a1cb3c0 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -483,8 +483,12 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i save_everystep = isempty(saveat), save_idxs = nothing, dense = save_everystep, save_on = true, - save_start = true, - save_end = true, + save_start = save_everystep || isempty(saveat) || + saveat isa Number ? true : + prob.tspan[1] in saveat, + save_end = save_everystep || isempty(saveat) || + saveat isa Number ? true : + prob.tspan[2] in saveat, save_timeseries = nothing, progress = false, progress_steps = 1000, @@ -1415,7 +1419,7 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = integrator.userfun.p = integrator.p solver_step(integrator, tstop) integrator.t = first(integrator.tout) - # NB: CVode, ARKode may warn and then recover if integrator.t == integrator.tprev so don't flag this as an error + # NB: CVode, ARKode may warn and then recover if integrator.t == integrator.tprev so don't flag this as an error integrator.flag < 0 && break handle_callbacks!(integrator) # this also updates the interpolation integrator.flag < 0 && break diff --git a/test/common_interface/arkode.jl b/test/common_interface/arkode.jl index b303bc9..1b6c7b4 100644 --- a/test/common_interface/arkode.jl +++ b/test/common_interface/arkode.jl @@ -56,3 +56,7 @@ method = ARKODE(Sundials.Explicit(); # Solve sol = solve(prob, method) @test sol.retcode == ReturnCode.Success + +#test that save_start and save_end are false by default when saveat is set +sol = solve(prob, ARKODE(), saveat = [0.1, 0.2]) +@test sol.t == [0.1, 0.2] diff --git a/test/common_interface/cvode.jl b/test/common_interface/cvode.jl index a40996c..a7b1e72 100644 --- a/test/common_interface/cvode.jl +++ b/test/common_interface/cvode.jl @@ -46,7 +46,7 @@ sol = solve(prob, CVODE_Adams(); saveat = saveat, save_everystep = false) @test sol.t == saveat for tstops in [0.9, [0.9]] - sol = solve(prob, CVODE_Adams(); tstops) + local sol = solve(prob, CVODE_Adams(); tstops) @test 0.9 ∈ sol.t end @@ -57,7 +57,7 @@ sol_idxs = solve(prob, CVODE_Adams(); save_idxs = [1], timeseries_errors = false sol_idxs = solve(prob, CVODE_Adams(); save_idxs = [1, 2], timeseries_errors = false, calculate_error = false) -@test length(sol_idxs[1]) == 2 +@test length(sol_idxs[:, 1]) == 2 @test sol[1, :] == sol_idxs[1, :] @test sol[2, :] == sol_idxs[2, :] diff --git a/test/common_interface/ida.jl b/test/common_interface/ida.jl index 7932c28..503e3ec 100644 --- a/test/common_interface/ida.jl +++ b/test/common_interface/ida.jl @@ -60,7 +60,7 @@ sol = solve(prob, IDA(); saveat = saveat, save_everystep = true) @test intersect(sol.t, saveat) == saveat @info "IDA with tstops" for tstops in [0.9, [0.9]] - sol = solve(prob, IDA(); tstops) + local sol = solve(prob, IDA(); tstops) @test 0.9 ∈ sol.t end diff --git a/test/common_interface/jacobians.jl b/test/common_interface/jacobians.jl index 482f174..29f358b 100644 --- a/test/common_interface/jacobians.jl +++ b/test/common_interface/jacobians.jl @@ -35,8 +35,7 @@ sol9 = solve(prob, CVODE_BDF(; linear_solver = :KLU)) @test Array(sol9) ≈ Array(good_sol) Lotka_fj = ODEFunction(Lotka; - jac_prototype = JacVec((du, u) -> Lotka(du, u, (), 0.0), ones(2), - SciMLBase.NullParameters())) + jac_prototype = JacVec(Lotka, ones(2), SciMLBase.NullParameters(), 0.0)) prob = ODEProblem(Lotka_fj, ones(2), (0.0, 10.0)) sol9 = solve(prob, CVODE_BDF(; linear_solver = :GMRES), saveat = 0.1, abstol = 1e-12, diff --git a/test/common_interface/precs.jl b/test/common_interface/precs.jl index 487ef66..058c444 100644 --- a/test/common_interface/precs.jl +++ b/test/common_interface/precs.jl @@ -1,6 +1,6 @@ using Sundials, Test, LinearAlgebra, IncompleteLU import AlgebraicMultigrid -import SparsityTracing, SparseDiffTools +import SparseConnectivityTracer, SparseDiffTools const N = 32 const xyd_brusselator = range(0; stop = 1, length = N) @@ -46,15 +46,14 @@ function init_brusselator_2d(xyd) u end u0 = vec(init_brusselator_2d(xyd_brusselator)) +du = similar(u0) prob_ode_brusselator_2d = ODEProblem(brusselator_2d_vec, u0, (0.0, 11.5), p) -# find Jacobian sparsity pattern -u0_st = SparsityTracing.create_advec(u0) -du_st = similar(u0_st) -brusselator_2d_vec(du_st, u0_st, p, 0.0) -const jaccache = SparsityTracing.jacobian(du_st, length(du_st)) +detector = SparseConnectivityTracer.TracerSparsityDetector() +brus_uf = (du, u)->brusselator_2d_vec(du, u, p, 0.1) +const jaccache = similar(SparseConnectivityTracer.jacobian_sparsity(brus_uf, du, u0, detector), Float64) const W = I - 1.0 * jaccache # setup sparse AD for Jacobian