From 2a4558bbcec1d12235044cd0297ec0c8f11d6446 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 30 Sep 2024 10:17:42 -0400 Subject: [PATCH 1/3] Always respect user provided `u0` override --- src/solve.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/solve.jl b/src/solve.jl index 59996ea6e..bbfc26aab 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1172,7 +1172,8 @@ function get_concrete_problem(prob, isadapt; kwargs...) tspan_promote = promote_tspan(u0_promote, p, tspan, prob, kwargs) f_promote = promote_f(prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, tspan_promote[1]) - if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) && + if isconcreteu0(prob, tspan[1], kwargs) && prob.u0 === u0 && + typeof(u0_promote) === typeof(prob.u0) && prob.tspan == tspan && typeof(prob.tspan) === typeof(tspan_promote) && p === prob.p && f_promote === prob.f return prob From 42291f30391b17805c391790393395e70eacacb0 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 30 Sep 2024 10:25:41 -0400 Subject: [PATCH 2/3] Add test --- test/downstream/prob_kwargs.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/downstream/prob_kwargs.jl b/test/downstream/prob_kwargs.jl index f4d7e2c1f..017466e0c 100644 --- a/test/downstream/prob_kwargs.jl +++ b/test/downstream/prob_kwargs.jl @@ -10,3 +10,7 @@ prob = ODEProblem(lorenz, u0, tspan, alg = Tsit5()) @test_nowarn sol = solve(prob, reltol = 1e-6) sol = solve(prob, reltol = 1e-6) @test sol.alg isa Tsit5 + +new_u0 = rand(3) +sol = solve(prob, u0 = new_u0) +@test sol.prob.u0 === new_u0 From 832a359cc52fe1498076225ee4f495700deaad77 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 30 Sep 2024 10:28:16 -0400 Subject: [PATCH 3/3] Format --- src/forwarddiff.jl | 5 ++++- src/integrator_accessors.jl | 9 ++++++--- src/solve.jl | 6 ++++-- test/downstream/tables.jl | 6 +++++- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 6eb3e1c7a..37d82885c 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -347,7 +347,10 @@ function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {coun anyeltypedual(values(x)) end -DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} = Any +function DiffEqBase.anyeltypedual( + f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} + Any +end @inline promote_u0(::Nothing, p, t0) = nothing diff --git a/src/integrator_accessors.jl b/src/integrator_accessors.jl index 3a7550918..b89dd3d1b 100644 --- a/src/integrator_accessors.jl +++ b/src/integrator_accessors.jl @@ -1,9 +1,12 @@ # the following are setup per how integrators are implemented in OrdinaryDiffEq and # StochasticDiffEq and provide dispatch points that JumpProcesses and others can use. -get_tstops(integ::DEIntegrator) = +function get_tstops(integ::DEIntegrator) error("get_tstops not implemented for integrators of type $(nameof(typeof(integ)))") -get_tstops_array(integ::DEIntegrator) = +end +function get_tstops_array(integ::DEIntegrator) error("get_tstops_array not implemented for integrators of type $(nameof(typeof(integ)))") -get_tstops_max(integ::DEIntegrator) = +end +function get_tstops_max(integ::DEIntegrator) error("get_tstops_max not implemented for integrators of type $(nameof(typeof(integ)))") +end diff --git a/src/solve.jl b/src/solve.jl index bbfc26aab..c16a5c6cb 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1389,7 +1389,8 @@ function __solve( kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) + elseif length(args) > 0 && !(first(args) isa + Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) throw(NonSolverError()) else __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) @@ -1400,7 +1401,8 @@ function __init(prob::AbstractDEProblem, args...; default_set = false, second_ti kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) + elseif length(args) > 0 && !(first(args) isa + Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) throw(NonSolverError()) else __init(prob, nothing, args...; default_set = false, second_time = true, kwargs...) diff --git a/test/downstream/tables.jl b/test/downstream/tables.jl index 7279eb621..84f2ebec4 100644 --- a/test/downstream/tables.jl +++ b/test/downstream/tables.jl @@ -5,7 +5,11 @@ sol1 = solve(prob, Euler(); dt = 1 // 2^(4)); df = DataFrame(sol1) @test names(df) == ["timestamp", "value1", "value2", "value3", "value4"] -prob = ODEProblem(ODEFunction(f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)), rand(2, 2), (0.0, 1.0)); +prob = ODEProblem( + ODEFunction( + f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)), + rand(2, 2), + (0.0, 1.0)); sol2 = solve(prob, Euler(); dt = 1 // 2^(4)); df = DataFrame(sol2) @test names(df) == ["timestamp", "a", "b", "c", "d"]