From 4be5d4a28102f36f6025080c358b3118a0fd8999 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 15 Nov 2023 11:39:43 -0500 Subject: [PATCH 1/4] `finalize` callbacks. fixes https://github.com/SciML/DiffEqBase.jl/issues/931 (although we probably need a similar line added for `Sundials.jl`) --- src/integrators/integrator_utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/integrators/integrator_utils.jl b/src/integrators/integrator_utils.jl index 6d04cc7762..4bc032d1a6 100644 --- a/src/integrators/integrator_utils.jl +++ b/src/integrators/integrator_utils.jl @@ -143,6 +143,7 @@ end postamble!(integrator::ODEIntegrator) = _postamble!(integrator) function _postamble!(integrator) + DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator) solution_endpoint_match_cur_integrator!(integrator) resize!(integrator.sol.t, integrator.saveiter) resize!(integrator.sol.u, integrator.saveiter) From f1ba09cd9b341c616f9f714f48a2613df0a70c02 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 15 Nov 2023 12:58:52 -0500 Subject: [PATCH 2/4] add test --- test/integrators/discrete_callback_dual_test.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/integrators/discrete_callback_dual_test.jl b/test/integrators/discrete_callback_dual_test.jl index d37ae7b916..c7376796c1 100644 --- a/test/integrators/discrete_callback_dual_test.jl +++ b/test/integrators/discrete_callback_dual_test.jl @@ -7,11 +7,11 @@ using OrdinaryDiffEq, Test, ForwardDiff u0 = 1.0 tspan = (0.0, 1.0) p = 1.0 - +X = 0 function stopping_cb(tstop) condition = (u, t, integrator) -> t == tstop affect! = integrator -> (println("Stopped!"); integrator.p = zero(integrator.p)) - DiscreteCallback(condition, affect!) + DiscreteCallback(condition, affect!, finalize=(args...)->X+=1) end function test_fun(tstop) @@ -22,6 +22,10 @@ function test_fun(tstop) end @test ForwardDiff.derivative(test_fun, 0.5) ≈ exp(0.5) * u0 # Analytical solution: exp(tstop)*u0 +@test X == 1 # test that finalize callback ran exactly once +test_fun(.5) +@test X == 2 # test that finalize callback ran again + function test_fun(tstop) DualT = typeof(tstop) From 87dabc5779fe29e8740f1d26fb60c83f505d5f69 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 15 Nov 2023 13:07:48 -0500 Subject: [PATCH 3/4] update test. --- test/integrators/discrete_callback_dual_test.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/integrators/discrete_callback_dual_test.jl b/test/integrators/discrete_callback_dual_test.jl index c7376796c1..329d6db7d3 100644 --- a/test/integrators/discrete_callback_dual_test.jl +++ b/test/integrators/discrete_callback_dual_test.jl @@ -7,11 +7,11 @@ using OrdinaryDiffEq, Test, ForwardDiff u0 = 1.0 tspan = (0.0, 1.0) p = 1.0 -X = 0 +times_finalize_called = 0 function stopping_cb(tstop) condition = (u, t, integrator) -> t == tstop affect! = integrator -> (println("Stopped!"); integrator.p = zero(integrator.p)) - DiscreteCallback(condition, affect!, finalize=(args...)->X+=1) + DiscreteCallback(condition, affect!, finalize=(args...)->times_finalize_called+=1) end function test_fun(tstop) @@ -22,9 +22,9 @@ function test_fun(tstop) end @test ForwardDiff.derivative(test_fun, 0.5) ≈ exp(0.5) * u0 # Analytical solution: exp(tstop)*u0 -@test X == 1 # test that finalize callback ran exactly once +@test times_finalize_called == 1 # test that finalize callback ran exactly once test_fun(.5) -@test X == 2 # test that finalize callback ran again +@test times_finalize_called == 2 # test that finalize callback ran again function test_fun(tstop) From 4a534a862f5fce38d7cc1a21c4b79af34abf93ff Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 15 Nov 2023 15:45:04 -0500 Subject: [PATCH 4/4] typo --- test/integrators/discrete_callback_dual_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integrators/discrete_callback_dual_test.jl b/test/integrators/discrete_callback_dual_test.jl index 329d6db7d3..2e87eae966 100644 --- a/test/integrators/discrete_callback_dual_test.jl +++ b/test/integrators/discrete_callback_dual_test.jl @@ -11,7 +11,7 @@ times_finalize_called = 0 function stopping_cb(tstop) condition = (u, t, integrator) -> t == tstop affect! = integrator -> (println("Stopped!"); integrator.p = zero(integrator.p)) - DiscreteCallback(condition, affect!, finalize=(args...)->times_finalize_called+=1) + DiscreteCallback(condition, affect!, finalize=(args...)->global times_finalize_called+=1) end function test_fun(tstop)