From a6a4a1fa5ef786d221deaae2c2f4e3fea098628b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Nov 2024 16:05:01 +0530 Subject: [PATCH] feat: do not require `nlsolve_alg` for trivial `OverrideInit` --- src/initialization.jl | 37 +++++++++++++++++++++++-------------- test/initialization.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 993d10686..9f5274eae 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -171,6 +171,9 @@ Keyword arguments: provided to the `OverrideInit` constructor takes priority over this keyword argument. If the former is `nothing`, this keyword argument will be used. If it is also not provided, an error will be thrown. + +In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are +not required. """ function get_initial_values(prob, valp, f, alg::OverrideInit, iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...) @@ -193,26 +196,32 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata.update_initializeprob!(initprob, valp) end - if alg.abstol !== nothing - _abstol = alg.abstol - elseif abstol !== nothing - _abstol = abstol - else - throw(OverrideInitNoTolerance(:abstol)) - end - if alg.reltol !== nothing - _reltol = alg.reltol - elseif reltol !== nothing - _reltol = reltol + if state_values(initprob) === nothing + nlsol = initprob + success = true else - throw(OverrideInitNoTolerance(:reltol)) + if alg.abstol !== nothing + _abstol = alg.abstol + elseif abstol !== nothing + _abstol = abstol + else + throw(OverrideInitNoTolerance(:abstol)) + end + if alg.reltol !== nothing + _reltol = alg.reltol + elseif reltol !== nothing + _reltol = reltol + else + throw(OverrideInitNoTolerance(:reltol)) + end + nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) + success = SciMLBase.successful_retcode(nlsol) end - nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) u0 = initdata.initializeprobmap(nlsol) if initdata.initializeprobpmap !== nothing p = initdata.initializeprobpmap(valp, nlsol) end - return u0, p, SciMLBase.successful_retcode(nlsol) + return u0, p, success end diff --git a/test/initialization.jl b/test/initialization.jl index e6211e59c..96fa965fd 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -229,4 +229,30 @@ end @test p ≈ 0.0 @test success end + + @testset "Trivial initialization" begin + initprob = NonlinearProblem(Returns(nothing), nothing, [1.0]) + update_initializeprob! = function (iprob, integ) + iprob.p[1] = integ.u[1] + end + initprobmap = function (nlsol) + u1 = parameter_values(nlsol)[1] + return [u1, u1] + end + initprobpmap = function (_, nlsol) + return 0.0 + end + initialization_data = SciMLBase.OverrideInitData( + initprob, update_initializeprob!, initprobmap, initprobpmap) + fn = ODEFunction(rhs2; initialization_data) + prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) + integ = init(prob; initializealg = NoInit()) + + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), Val(false) + ) + @test u0 ≈ [2.0, 2.0] + @test p ≈ 0.0 + @test success + end end