From 40aaab2cf28df4a8b5292830a1c0632b3bcad83b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 8 Nov 2024 13:25:06 -0100 Subject: [PATCH 1/2] Fix remake of ForwardDiffSensitivity Fixes https://github.com/SciML/SciMLSensitivity.jl/issues/1137 --- src/forward_sensitivity.jl | 10 ++++++++-- test/forward.jl | 12 ++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/forward_sensitivity.jl b/src/forward_sensitivity.jl index b6527400a..f23cceed4 100644 --- a/src/forward_sensitivity.jl +++ b/src/forward_sensitivity.jl @@ -675,8 +675,14 @@ function SciMLBase.remake( {uType, tType, isinplace, P, F, K} _p = p === nothing ? parameter_values(prob) : p _f = f === nothing ? prob.f.f : f - _u0 = u0 === nothing ? state_values(prob, 1:(prob.f.numindvar)) : - u0[1:(prob.f.numindvar)] + + if typeof(_f) <: ODEForwardSensitivityFunction + _u0 = u0 === nothing ? state_values(prob, 1:(_f.numindvar)) : + u0[1:(_f.numindvar)] + else + _u0 = u0 === nothing ? state_values(prob) : u0 + end + _tspan = tspan === nothing ? prob.tspan : tspan ODEForwardSensitivityProblem(_f, _u0, _tspan, _p; sensealg = prob.problem_type.sensealg, diff --git a/test/forward.jl b/test/forward.jl index 08158aec3..a76f81aa4 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -271,3 +271,15 @@ f = prob.f @assert f isa ODEForwardSensitivityFunction @test hasproperty(f, :observed) @test f.observed == SciMLBase.DEFAULT_OBSERVED + +# `remake`: https://github.com/SciML/SciMLSensitivity.jl/issues/1137 + +function f(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + u[1] * u[2] +end + +p = [1.5, 1.0, 3.0] +ts = (0, 10) +prob = ODEForwardSensitivityProblem(f, [1.0; 1.0], ts, p, sensealg=ForwardDiffSensitivity()) +sol = solve(prob, Tsit5()) \ No newline at end of file From 9a8c5b24198eede160fff1bea3bd560b276fe203 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 8 Nov 2024 15:17:31 -0100 Subject: [PATCH 2/2] Update test/forward.jl --- test/forward.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/forward.jl b/test/forward.jl index a76f81aa4..71a819c8b 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -274,12 +274,12 @@ f = prob.f # `remake`: https://github.com/SciML/SciMLSensitivity.jl/issues/1137 -function f(du, u, p, t) +function ff3(du, u, p, t) du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] du[2] = dy = -p[3] * u[2] + u[1] * u[2] end p = [1.5, 1.0, 3.0] ts = (0, 10) -prob = ODEForwardSensitivityProblem(f, [1.0; 1.0], ts, p, sensealg=ForwardDiffSensitivity()) +prob = ODEForwardSensitivityProblem(ff3, [1.0; 1.0], ts, p, sensealg=ForwardDiffSensitivity()) sol = solve(prob, Tsit5()) \ No newline at end of file