diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 15b7aa916..01ad56a48 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -85,4 +85,4 @@ jobs: with: file: lcov.info token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true + fail_ci_if_error: false diff --git a/src/ODE_nlsolve.jl b/src/ODE_nlsolve.jl index 0e45d2d0e..cf8a5fa74 100644 --- a/src/ODE_nlsolve.jl +++ b/src/ODE_nlsolve.jl @@ -43,4 +43,3 @@ struct ODE_NLProbData{NLProb, UNLProb, NLProbMap, NLProbPmap} """ nlprobpmap::NLProbPmap end - diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 5ab13494c..5cbc1b4a8 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve! import FunctionWrappersWrappers import RuntimeGeneratedFunctions import EnumX -import ADTypes: AbstractADType +import ADTypes: ADTypes, AbstractADType import Accessors: @set, @reset using Expronicon.ADT: @match @@ -351,7 +351,16 @@ struct CheckInit <: DAEInitializationAlgorithm end """ $(TYPEDEF) """ -struct OverrideInit <: DAEInitializationAlgorithm end +struct OverrideInit{T1, T2, F} <: DAEInitializationAlgorithm + abstol::T1 + reltol::T2 + nlsolve::F +end + +function OverrideInit(; abstol = nothing, reltol = nothing, nlsolve = nothing) + OverrideInit(abstol, reltol, nlsolve) +end +OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing) # PDE Discretizations diff --git a/src/initialization.jl b/src/initialization.jl index 58610269b..778843382 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -68,17 +68,26 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm) "OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.") end +struct OverrideInitNoTolerance <: Exception + tolerance::Symbol +end + +function Base.showerror(io::IO, e::OverrideInitNoTolerance) + print(io, + "Tolerances were not provided to `OverrideInit`. `$(e.tolerance)` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor.") +end + """ -Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if +Utility function to evaluate the RHS, using the integrator's `tmp_cache` if it is in-place or simply calling the function if not. """ -function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...) +function _evaluate_f(integrator, f, isinplace::Val{true}, args...) tmp = first(get_tmp_cache(integrator)) f(tmp, args...) return tmp end -function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...) +function _evaluate_f(integrator, f, isinplace::Val{false}, args...) return f(args...) end @@ -98,10 +107,16 @@ _vec(v::AbstractVector) = v Check if the algebraic constraints are satisfied, and error if they aren't. Returns the `u0` and `p` as-is, and is always successful if it returns. Valid only for -`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument. +`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument. + +Keyword arguments: +- `abstol`: The absolute value below which the norm of the residual of algebraic equations + should lie. The norm function used is `integrator.opts.internalnorm` if present, and + `LinearAlgebra.norm` if not. """ -function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit, - isinplace::Union{Val{true}, Val{false}}; kwargs...) +function get_initial_values( + prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit, + isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...) u0 = state_values(integrator) p = parameter_values(integrator) t = current_time(integrator) @@ -109,42 +124,32 @@ function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckI algebraic_vars = [all(iszero, x) for x in eachcol(M)] algebraic_eqs = [all(iszero, x) for x in eachrow(M)] - (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return + (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true update_coefficients!(M, u0, p, t) - tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t) + tmp = _evaluate_f(integrator, f, isinplace, u0, p, t) tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) - normresid = integrator.opts.internalnorm(tmp, t) - if normresid > integrator.opts.abstol - throw(CheckInitFailureError(normresid, integrator.opts.abstol)) + normresid = isdefined(integrator.opts, :internalnorm) ? + integrator.opts.internalnorm(tmp, t) : norm(tmp) + if normresid > abstol + throw(CheckInitFailureError(normresid, abstol)) end return u0, p, true end -""" -Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if -it is in-place or simply calling the function if not. -""" -function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...) - tmp = get_tmp_cache(integrator)[2] - f(tmp, args...) - return tmp -end - -function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...) - return f(args...) -end - -function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit, - isinplace::Union{Val{true}, Val{false}}; kwargs...) +function get_initial_values( + prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit, + isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...) u0 = state_values(integrator) p = parameter_values(integrator) t = current_time(integrator) - resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t) - normresid = integrator.opts.internalnorm(resid, t) - if normresid > integrator.opts.abstol - throw(CheckInitFailureError(normresid, integrator.opts.abstol)) + resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t) + normresid = isdefined(integrator.opts, :internalnorm) ? + integrator.opts.internalnorm(resid, t) : norm(resid) + + if normresid > abstol + throw(CheckInitFailureError(normresid, abstol)) end return u0, p, true end @@ -155,12 +160,19 @@ end Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and `p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`. If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is. -The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword -argument, failing which this function will throw an error. The success value returned -depends on the success of the nonlinear solve. + +The success value returned depends on the success of the nonlinear solve. + +Keyword arguments: +- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will + throw an error. +- `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value + provided to the `OverrideInit` constructor takes priority over this keyword argument. + If the former is `nothing`, this keyword argument will be used. If it is also not provided, + an error will be thrown. """ function get_initial_values(prob, valp, f, alg::OverrideInit, - isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...) + iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...) u0 = state_values(valp) p = parameter_values(valp) @@ -171,7 +183,8 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata::OverrideInitData = f.initialization_data initprob = initdata.initializeprob - if nlsolve_alg === nothing + nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) + if nlsolve_alg === nothing && state_values(initprob) !== nothing throw(OverrideInitMissingAlgorithm()) end @@ -179,7 +192,21 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata.update_initializeprob!(initprob, valp) end - nlsol = solve(initprob, nlsolve_alg) + if alg.abstol !== nothing + _abstol = alg.abstol + elseif abstol !== nothing + _abstol = abstol + else + throw(OverrideInitNoTolerance(:abstol)) + end + if alg.reltol !== nothing + _reltol = alg.reltol + elseif reltol !== nothing + _reltol = reltol + else + throw(OverrideInitNoTolerance(:reltol)) + end + nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) u0 = initdata.initializeprobmap(nlsol) if initdata.initializeprobpmap !== nothing diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index afd1dc80a..aeddf301b 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -401,7 +401,8 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip} + SYS, ID <: Union{Nothing, OverrideInitData}, NLP <: Union{Nothing, ODE_NLProbData}} <: + AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -522,7 +523,8 @@ information on generating the SplitFunction from this symbolic engine. """ struct SplitFunction{ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt, - TPJ, O, TCV, SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip} + TPJ, O, TCV, SYS, ID <: Union{Nothing, OverrideInitData}, + NLP <: Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip} f1::F1 f2::F2 mass_matrix::TMM @@ -2442,7 +2444,7 @@ function ODEFunction{iip, specialize}(f; initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing, initialization_data = __has_initialization_data(f) ? f.initialization_data : nothing, - nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing, + nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing ) where {iip, specialize } @@ -2500,7 +2502,8 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}( + _f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, observed, _colorvec, sys, initdata, nlprob_data) @@ -2770,7 +2773,8 @@ function SplitFunction{iip, specialize}(f1, f2; if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(f1, f2, mass_matrix, _func_cache, + Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}( + f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, diff --git a/test/downstream/initialization.jl b/test/downstream/initialization.jl new file mode 100644 index 000000000..a7d5ee671 --- /dev/null +++ b/test/downstream/initialization.jl @@ -0,0 +1,61 @@ +using OrdinaryDiffEq, Sundials, SciMLBase, Test + +@testset "CheckInit" begin + abstol = 1e-10 + @testset "Sundials + ODEProblem" begin + function rhs(u, p, t) + return [u[1] * t, u[1]^2 - u[2]^2] + end + function rhs!(du, u, p, t) + du[1] = u[1] * t + du[2] = u[1]^2 - u[2]^2 + end + + oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0]) + iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0]) + + @testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] + prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0)) + integ = init(prob, Sundials.ARKODE()) + u0, _, success = SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) + @test success + @test u0 == prob.u0 + + integ.u[2] = 2.0 + @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) + end + end + + @testset "Sundials + DAEProblem" begin + function daerhs(du, u, p, t) + return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2] + end + function daerhs!(resid, du, u, p, t) + resid[1] = du[1] - u[1] * t - p + resid[2] = u[1]^2 - u[2]^2 + end + + oopfn = DAEFunction{false}(daerhs) + iipfn = DAEFunction{true}(daerhs!) + + @testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] + prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0) + integ = init(prob, Sundials.IDA()) + u0, _, success = SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) + @test success + @test u0 == prob.u0 + + integ.u[2] = 2.0 + @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) + + integ.u[2] = 1.0 + integ.du[1] = 2.0 + @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) + end + end +end diff --git a/test/initialization.jl b/test/initialization.jl index ca8fb6b6c..4ed011569 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test +using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test @testset "CheckInit" begin @testset "ODEProblem" begin @@ -17,13 +17,15 @@ using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0)) integ = init(prob) u0, _, success = SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) @test success @test u0 == prob.u0 integ.u[2] = 2.0 @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) end end @@ -43,18 +45,61 @@ using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0) integ = init(prob, DImplicitEuler()) u0, _, success = SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) @test success @test u0 == prob.u0 integ.u[2] = 2.0 @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) integ.u[2] = 1.0 integ.du[1] = 2.0 @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) + end + end + + @testset "SDEProblem" begin + mm_A = [1 0 0; 0 1 0; 0 0 0] + function sdef!(du, u, p, t) + du[1] = u[1] + du[2] = u[2] + du[3] = u[1] + u[2] + u[3] - 1 + end + function sdef(u, p, t) + du = similar(u) + sdef!(du, u, p, t) + du + end + + function g!(du, u, p, t) + @. du = 0.1 + end + function g(u, p, t) + du = similar(u) + g!(du, u, p, t) + du + end + iipfn = SDEFunction{true}(sdef!, g!; mass_matrix = mm_A) + oopfn = SDEFunction{false}(sdef, g; mass_matrix = mm_A) + + @testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] + prob = SDEProblem(f, [1.0, 1.0, -1.0], (0.0, 1.0)) + integ = init(prob, ImplicitEM()) + u0, _, success = SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) + @test success + @test u0 == prob.u0 + + integ.u[2] = 2.0 + @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) end end end @@ -100,23 +145,54 @@ end prob, integ, fn, SciMLBase.OverrideInit(), Val(false)) end + abstol = 1e-10 + reltol = 1e-10 @testset "Solves" begin - u0, p, success = SciMLBase.get_initial_values( - prob, integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + @testset "with explicit alg" begin + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), + Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol) - @test u0 ≈ [2.0, 2.0] - @test p ≈ 1.0 - @test success + @test u0 ≈ [2.0, 2.0] + @test p ≈ 1.0 + @test success - initprob.p[1] = 1.0 + initprob.p[1] = 1.0 + end + @testset "with alg in `OverrideInit`" begin + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, + SciMLBase.OverrideInit(; nlsolve = NewtonRaphson(), abstol, reltol), + Val(false)) + + @test u0 ≈ [2.0, 2.0] + @test p ≈ 1.0 + @test success + + initprob.p[1] = 1.0 + end + @testset "with trivial problem and no alg" begin + iprob = NonlinearProblem((u, p) -> 0.0, nothing, 1.0) + iprobmap = (_) -> [1.0, 1.0] + initdata = SciMLBase.OverrideInitData(iprob, nothing, iprobmap, nothing) + _fn = ODEFunction(rhs2; initialization_data = initdata) + _prob = ODEProblem(_fn, [2.0, 0.0], (0.0, 1.0), 1.0) + _integ = init(_prob; initializealg = NoInit()) + + u0, p, success = SciMLBase.get_initial_values( + _prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false); abstol, reltol) + + @test u0 ≈ [1.0, 1.0] + @test p ≈ 1.0 + @test success + end end @testset "Solves with non-integrator value provider" begin _integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t) u0, p, success = SciMLBase.get_initial_values( prob, _integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol) @test u0 ≈ [2.0, 2.0] @test p ≈ 1.0 @@ -133,7 +209,7 @@ end u0, p, success = SciMLBase.get_initial_values( prob, integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol) @test u0 ≈ [1.0, 1.0] @test p ≈ 1.0 @test success @@ -147,7 +223,7 @@ end u0, p, success = SciMLBase.get_initial_values( prob, integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol) @test u0 ≈ [2.0, 2.0] @test p ≈ 0.0 diff --git a/test/runtests.jl b/test/runtests.jl index 747c7251d..985a71c64 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -121,6 +121,9 @@ end @time @safetestset "Tables interface with MTK" begin include("downstream/tables.jl") end + @time @safetestset "Initialization" begin + include("downstream/initialization.jl") + end end if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface")