From 16ecf435c2dfe8b45aedb3814de8d5e4ecb69d97 Mon Sep 17 00:00:00 2001 From: Stuart Daines Date: Thu, 25 Aug 2022 14:33:42 +0100 Subject: [PATCH 1/2] Robustness fixes: check model and modeldata are consistent - use PB.check_modeldata(model, modeldata) to check consistency of model and modeldata - deprecation warning for initialize!(run) - bugfix for Kinsol wrapper --- Project.toml | 2 +- src/JacobianAD.jl | 8 ++++++-- src/Kinsol.jl | 25 +++++++++---------------- src/ODE.jl | 4 +++- src/ODELocalIMEX.jl | 6 ++++++ src/ODEfixed.jl | 12 ++++++++++-- src/Run.jl | 8 ++++++-- src/SteadyState.jl | 5 +++++ src/SteadyStateKinsol.jl | 2 ++ 9 files changed, 48 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index df436ae..474b97e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PALEOmodel" uuid = "bf7b4fbe-ccb1-42c5-83c2-e6e9378b660c" authors = ["Stuart Daines "] -version = "0.15.7" +version = "0.15.8" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" diff --git a/src/JacobianAD.jl b/src/JacobianAD.jl index 468f0cf..4d621c9 100644 --- a/src/JacobianAD.jl +++ b/src/JacobianAD.jl @@ -53,7 +53,9 @@ function jac_config_ode( jac_cellranges=modeldata.cellranges_all, init_logger=Logging.NullLogger(), ) - @info "jac_config_ode: jac_ad=$jac_ad" + @info "jac_config_ode: jac_ad=$jac_ad" + + PB.check_modeldata(model, modeldata) iszero(PALEOmodel.num_total(modeldata.solver_view_all)) || throw(ArgumentError("model contains implicit variables, solve as a DAE")) @@ -158,7 +160,9 @@ function jac_config_dae( implicit_cellranges=modeldata.cellranges_all, init_logger=Logging.NullLogger(), ) - @info "jac_config_dae: jac_ad=$jac_ad" + @info "jac_config_dae: jac_ad=$jac_ad" + + PB.check_modeldata(model, modeldata) # generate arrays with ODE layout for model Variables state_sms_vars_data = similar(PALEOmodel.get_statevar_sms(modeldata.solver_view_all)) diff --git a/src/Kinsol.jl b/src/Kinsol.jl index 886bed3..0e70eb6 100644 --- a/src/Kinsol.jl +++ b/src/Kinsol.jl @@ -11,6 +11,8 @@ module Kinsol import Sundials +# import Infiltrator + ########################################################### # Internals: Julia <-> C wrapper functions ########################################################### @@ -25,24 +27,17 @@ mutable struct UserFunctionAndData{F1, F2, F3, F4} data::Any end -UserFunctionAndData(func, data) = UserFunctionAndData(func, nothing, nothing, nothing, data) -UserFunctionAndData(func) = func -UserFunctionAndData(func, psetup::Nothing, psolve::Nothing, jv::Nothing, data::Nothing) = func - # Julia adaptor function with C types, passed in to Kinsol C code as a callback # wraps C types and forwards to the Julia user function -function kinsolfun(y::Sundials.N_Vector, fy::Sundials.N_Vector, userfun::UserFunctionAndData) - # @Infiltrator.infiltrate +function kinsolfun( + y::Sundials.N_Vector, + fy::Sundials.N_Vector, + userfun::UserFunctionAndData +) userfun.func(convert(Vector, fy), convert(Vector, y), userfun.data) return Sundials.KIN_SUCCESS end -function kinsolfun(y::Sundials.N_Vector, fy::Sundials.N_Vector, userfun) - # @Infiltrator.infiltrate - userfun(convert(Vector, fy), convert(Vector, y)) - return Sundials.KIN_SUCCESS -end - function kinprecsetup( u::Sundials.N_Vector, uscale::Sundials.N_Vector, @@ -69,7 +64,6 @@ function kinprecsolve( v::Sundials.N_Vector, userfun::UserFunctionAndData ) - retval = userfun.psolve( convert(Vector, u), convert(Vector, uscale), @@ -87,7 +81,7 @@ function kinjactimesvec( u::Sundials.N_Vector, new_u::Ptr{Cint}, userfun::UserFunctionAndData -) +) retval = userfun.jv( convert(Vector, v), convert(Vector, Jv), @@ -136,6 +130,7 @@ function kin_create( # use the user_data field to pass a function # see: https://github.com/JuliaLang/julia/issues/2554 userfun = UserFunctionAndData(f, psetupfun, psolvefun, jvfun, userdata) + push!(handles, userfun) # prevent userfun from being garbage collected (required for julia 1.8) function getkinsolfun(userfun::T) where {T} @cfunction(kinsolfun, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ref{T})) end @@ -220,10 +215,8 @@ function kin_solve( flag = Sundials.@checkflag Sundials.KINSetNoInitSetup(kmem, noInitSetup) true ## Solve problem - # @Infiltrator.infiltrate returnflag = Sundials.KINSol(kmem, y, strategy, y_scale, f_scale) - ## Get stats nfevals = [0] flag = Sundials.@checkflag Sundials.KINGetNumFuncEvals(kmem, nfevals) diff --git a/src/ODE.jl b/src/ODE.jl index a704b97..1d357dd 100644 --- a/src/ODE.jl +++ b/src/ODE.jl @@ -52,6 +52,7 @@ function ODEfunction( jac_ad_t_sparsity=nothing, init_logger=Logging.NullLogger(), ) + PB.check_modeldata(model, modeldata) # check for implicit total variables PALEOmodel.num_total(modeldata.solver_view_all) == 0 || @@ -101,8 +102,9 @@ function DAEfunction( jac_ad_t_sparsity=nothing, init_logger=Logging.NullLogger(), ) - @info "DAEfunction: using Jacobian $jac_ad" + + PB.check_modeldata(model, modeldata) jac, jac_prototype, odeimplicit = PALEOmodel.JacobianAD.jac_config_dae( jac_ad, model, initial_state, modeldata, jac_ad_t_sparsity, diff --git a/src/ODELocalIMEX.jl b/src/ODELocalIMEX.jl index 78494e6..024ca8a 100644 --- a/src/ODELocalIMEX.jl +++ b/src/ODELocalIMEX.jl @@ -59,6 +59,8 @@ function integrateLocalIMEXEuler( @info "integrateLocalIMEXEuler: Δt_outer=$Δt_outer (yr)" + PB.check_modeldata(run.model, modeldata) + solver_view_outer = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_outer) @info "solver_view_outer: $(solver_view_outer)" @@ -106,6 +108,8 @@ function timestep_LocalImplicit( deriv_only=false, integrator_barrier=nothing, ) + PB.check_modeldata(model, modeldata) + length(cellranges) == 1 || error("timestep_LocalImplicit only single cellrange supported") cellrange = cellranges[1] @@ -180,6 +184,7 @@ function create_timestep_LocalImplicit_ctxt( niter_max, Lnorm_inf_max ) + PB.check_modeldata(model, modeldata) lictxt = PALEOmodel.ODELocalIMEX.getLocalImplicitContext( model, modeldata, cellrange, exclude_var_nameroots, @@ -196,6 +201,7 @@ function getLocalImplicitContext( request_adchunksize=ForwardDiff.DEFAULT_CHUNK_THRESHOLD, init_logger=Logging.NullLogger(), ) + PB.check_modeldata(model, modeldata) # create SolverViews for first cell, to work out how many dof we need cellrange_cell = PB.CellRange(cellrange.domain, cellrange.operatorID, first(cellrange.indices) ) diff --git a/src/ODEfixed.jl b/src/ODEfixed.jl index 2314f37..279c36d 100644 --- a/src/ODEfixed.jl +++ b/src/ODEfixed.jl @@ -34,6 +34,7 @@ function integrateEuler( outputwriter=run.output, report_interval=1000 ) + PB.check_modeldata(run.model, modeldata) timesteppers = [ [( @@ -96,6 +97,8 @@ function integrateSplitEuler( @info "integrateSplitEuler: Δt_outer=$Δt_outer (yr) n_inner=$n_inner" + PB.check_modeldata(run.model, modeldata) + solver_view_outer = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_outer) @info "solver_view_outer: $(solver_view_outer)" solver_view_inner = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_inner) @@ -159,11 +162,12 @@ function integrateEulerthreads( outputwriter=run.output, report_interval=1000, ) + PB.check_modeldata(run.model, modeldata) nt = Threads.nthreads() nt == 1 || modeldata.threadsafe || error("integrateEulerthreads: Threads.nthreads() = $nt but modeldata is not thread safe "* - "(check initialize!(run::Run, ...))") + "(check initialize!(run.model, ...))") lc = length(cellranges) lc == nt || @@ -235,7 +239,8 @@ function integrateSplitEulerthreads( outputwriter=run.output, report_interval=1000, ) - + PB.check_modeldata(run.model, modeldata) + nt = Threads.nthreads() nt == 1 || modeldata.threadsafe || error("integrateEulerthreads: Threads.nthreads() = $nt but modeldata is not thread safe (check initialize!(run::Run, ...))") @@ -338,6 +343,7 @@ function create_timestep_Euler_ctxt( n_substep=1, verbose=false, ) + PB.check_modeldata(model, modeldata) num_constraints = PALEOmodel.num_algebraic_constraints(solver_view) iszero(num_constraints) || error("DAE problem with $num_constraints algebraic constraints") @@ -368,6 +374,7 @@ function integrateFixed( outputwriter=run.output, report_interval=1000 ) + PB.check_modeldata(run.model, modeldata) nevals = 0 @@ -447,6 +454,7 @@ function integrateFixedthreads( outputwriter=run.output, report_interval=1000 ) + PB.check_modeldata(run.model, modeldata) nevals = 0 diff --git a/src/Run.jl b/src/Run.jl index 60b5c5d..1ea7c62 100644 --- a/src/Run.jl +++ b/src/Run.jl @@ -26,11 +26,15 @@ function Base.show(io::IO, ::MIME"text/plain", run::Run) end -initialize!(run::Run; kwargs...) = initialize!(run.model; kwargs...) +function initialize!(run::Run; kwargs...) + Base.depwarn("call to deprecated initialize!(run::Run; ...), please update your code to use initialize!(run.model; ...)", :initialize!, force=true) + + return initialize!(run.model; kwargs...) +end """ initialize!(model::PB.Model; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData) - initialize!(run::Run; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData) + [deprecated] initialize!(run::Run; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData) Initialize `model` or `run.model` and return: - an `initial_state` Vector diff --git a/src/SteadyState.jl b/src/SteadyState.jl index a7b229b..78ad9ae 100644 --- a/src/SteadyState.jl +++ b/src/SteadyState.jl @@ -53,6 +53,7 @@ function steadystate( use_norm::Bool=false, BLAS_num_threads=1, ) + PB.check_modeldata(run.model, modeldata) LinearAlgebra.BLAS.set_num_threads(BLAS_num_threads) @info "steadystate: using BLAS with $(LinearAlgebra.BLAS.get_num_threads()) threads" @@ -189,6 +190,8 @@ function steadystate_ptc( verbose=false, BLAS_num_threads=1 ) + PB.check_modeldata(run.model, modeldata) + !use_norm || ArgumentError("use_norm=true not supported") nlsolveF = nlsolveF_PTC( @@ -272,6 +275,8 @@ function nlsolveF_PTC( request_adchunksize=10, jac_cellranges=modeldata.cellranges_all, ) + PB.check_modeldata(model, modeldata) + sv = modeldata.solver_view_all # We only support explicit ODE-like configurations (no DAE constraints or implicit variables) diff --git a/src/SteadyStateKinsol.jl b/src/SteadyStateKinsol.jl index bf47853..55d06fd 100644 --- a/src/SteadyStateKinsol.jl +++ b/src/SteadyStateKinsol.jl @@ -62,6 +62,8 @@ function steadystate_ptc( verbose=false, BLAS_num_threads=1 ) + PB.check_modeldata(run.model, modeldata) + # start, end times tss, tss_max = tspan From 3a639638e78999e5b556b3d7189eb0a1fe1a01f4 Mon Sep 17 00:00:00 2001 From: Stuart Daines Date: Thu, 25 Aug 2022 15:37:04 +0100 Subject: [PATCH 2/2] testing Kinsol workarounds --- src/Kinsol.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Kinsol.jl b/src/Kinsol.jl index 0e70eb6..cee1666 100644 --- a/src/Kinsol.jl +++ b/src/Kinsol.jl @@ -81,7 +81,7 @@ function kinjactimesvec( u::Sundials.N_Vector, new_u::Ptr{Cint}, userfun::UserFunctionAndData -) +) retval = userfun.jv( convert(Vector, v), convert(Vector, Jv), @@ -130,7 +130,7 @@ function kin_create( # use the user_data field to pass a function # see: https://github.com/JuliaLang/julia/issues/2554 userfun = UserFunctionAndData(f, psetupfun, psolvefun, jvfun, userdata) - push!(handles, userfun) # prevent userfun from being garbage collected (required for julia 1.8) + # push!(handles, userfun) # TODO prevent userfun from being garbage collected ? function getkinsolfun(userfun::T) where {T} @cfunction(kinsolfun, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ref{T})) end