Skip to content

Commit

Permalink
Merge pull request #24 from PALEOtoolkit/solver_robustness
Browse files Browse the repository at this point in the history
Solver robustness fixes
  • Loading branch information
sjdaines authored Aug 25, 2022
2 parents 2d77d31 + 3a63963 commit c9f3936
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PALEOmodel"
uuid = "bf7b4fbe-ccb1-42c5-83c2-e6e9378b660c"
authors = ["Stuart Daines <[email protected]>"]
version = "0.15.7"
version = "0.15.8"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand Down
8 changes: 6 additions & 2 deletions src/JacobianAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ function jac_config_ode(
jac_cellranges=modeldata.cellranges_all,
init_logger=Logging.NullLogger(),
)
@info "jac_config_ode: jac_ad=$jac_ad"
@info "jac_config_ode: jac_ad=$jac_ad"

PB.check_modeldata(model, modeldata)

iszero(PALEOmodel.num_total(modeldata.solver_view_all)) ||
throw(ArgumentError("model contains implicit variables, solve as a DAE"))
Expand Down Expand Up @@ -158,7 +160,9 @@ function jac_config_dae(
implicit_cellranges=modeldata.cellranges_all,
init_logger=Logging.NullLogger(),
)
@info "jac_config_dae: jac_ad=$jac_ad"
@info "jac_config_dae: jac_ad=$jac_ad"

PB.check_modeldata(model, modeldata)

# generate arrays with ODE layout for model Variables
state_sms_vars_data = similar(PALEOmodel.get_statevar_sms(modeldata.solver_view_all))
Expand Down
25 changes: 9 additions & 16 deletions src/Kinsol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ module Kinsol

import Sundials

# import Infiltrator

###########################################################
# Internals: Julia <-> C wrapper functions
###########################################################
Expand All @@ -25,24 +27,17 @@ mutable struct UserFunctionAndData{F1, F2, F3, F4}
data::Any
end

UserFunctionAndData(func, data) = UserFunctionAndData(func, nothing, nothing, nothing, data)
UserFunctionAndData(func) = func
UserFunctionAndData(func, psetup::Nothing, psolve::Nothing, jv::Nothing, data::Nothing) = func

# Julia adaptor function with C types, passed in to Kinsol C code as a callback
# wraps C types and forwards to the Julia user function
function kinsolfun(y::Sundials.N_Vector, fy::Sundials.N_Vector, userfun::UserFunctionAndData)
# @Infiltrator.infiltrate
function kinsolfun(
y::Sundials.N_Vector,
fy::Sundials.N_Vector,
userfun::UserFunctionAndData
)
userfun.func(convert(Vector, fy), convert(Vector, y), userfun.data)
return Sundials.KIN_SUCCESS
end

function kinsolfun(y::Sundials.N_Vector, fy::Sundials.N_Vector, userfun)
# @Infiltrator.infiltrate
userfun(convert(Vector, fy), convert(Vector, y))
return Sundials.KIN_SUCCESS
end

function kinprecsetup(
u::Sundials.N_Vector,
uscale::Sundials.N_Vector,
Expand All @@ -69,7 +64,6 @@ function kinprecsolve(
v::Sundials.N_Vector,
userfun::UserFunctionAndData
)

retval = userfun.psolve(
convert(Vector, u),
convert(Vector, uscale),
Expand All @@ -87,7 +81,7 @@ function kinjactimesvec(
u::Sundials.N_Vector,
new_u::Ptr{Cint},
userfun::UserFunctionAndData
)
)
retval = userfun.jv(
convert(Vector, v),
convert(Vector, Jv),
Expand Down Expand Up @@ -136,6 +130,7 @@ function kin_create(
# use the user_data field to pass a function
# see: https://github.com/JuliaLang/julia/issues/2554
userfun = UserFunctionAndData(f, psetupfun, psolvefun, jvfun, userdata)
# push!(handles, userfun) # TODO prevent userfun from being garbage collected ?
function getkinsolfun(userfun::T) where {T}
@cfunction(kinsolfun, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ref{T}))
end
Expand Down Expand Up @@ -220,10 +215,8 @@ function kin_solve(
flag = Sundials.@checkflag Sundials.KINSetNoInitSetup(kmem, noInitSetup) true

## Solve problem
# @Infiltrator.infiltrate
returnflag = Sundials.KINSol(kmem, y, strategy, y_scale, f_scale)


## Get stats
nfevals = [0]
flag = Sundials.@checkflag Sundials.KINGetNumFuncEvals(kmem, nfevals)
Expand Down
4 changes: 3 additions & 1 deletion src/ODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ function ODEfunction(
jac_ad_t_sparsity=nothing,
init_logger=Logging.NullLogger(),
)
PB.check_modeldata(model, modeldata)

# check for implicit total variables
PALEOmodel.num_total(modeldata.solver_view_all) == 0 ||
Expand Down Expand Up @@ -101,8 +102,9 @@ function DAEfunction(
jac_ad_t_sparsity=nothing,
init_logger=Logging.NullLogger(),
)

@info "DAEfunction: using Jacobian $jac_ad"

PB.check_modeldata(model, modeldata)

jac, jac_prototype, odeimplicit = PALEOmodel.JacobianAD.jac_config_dae(
jac_ad, model, initial_state, modeldata, jac_ad_t_sparsity,
Expand Down
6 changes: 6 additions & 0 deletions src/ODELocalIMEX.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ function integrateLocalIMEXEuler(

@info "integrateLocalIMEXEuler: Δt_outer=$Δt_outer (yr)"

PB.check_modeldata(run.model, modeldata)

solver_view_outer = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_outer)
@info "solver_view_outer: $(solver_view_outer)"

Expand Down Expand Up @@ -106,6 +108,8 @@ function timestep_LocalImplicit(
deriv_only=false,
integrator_barrier=nothing,
)
PB.check_modeldata(model, modeldata)

length(cellranges) == 1 || error("timestep_LocalImplicit only single cellrange supported")
cellrange = cellranges[1]

Expand Down Expand Up @@ -180,6 +184,7 @@ function create_timestep_LocalImplicit_ctxt(
niter_max,
Lnorm_inf_max
)
PB.check_modeldata(model, modeldata)

lictxt = PALEOmodel.ODELocalIMEX.getLocalImplicitContext(
model, modeldata, cellrange, exclude_var_nameroots,
Expand All @@ -196,6 +201,7 @@ function getLocalImplicitContext(
request_adchunksize=ForwardDiff.DEFAULT_CHUNK_THRESHOLD,
init_logger=Logging.NullLogger(),
)
PB.check_modeldata(model, modeldata)

# create SolverViews for first cell, to work out how many dof we need
cellrange_cell = PB.CellRange(cellrange.domain, cellrange.operatorID, first(cellrange.indices) )
Expand Down
12 changes: 10 additions & 2 deletions src/ODEfixed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ function integrateEuler(
outputwriter=run.output,
report_interval=1000
)
PB.check_modeldata(run.model, modeldata)

timesteppers = [
[(
Expand Down Expand Up @@ -96,6 +97,8 @@ function integrateSplitEuler(

@info "integrateSplitEuler: Δt_outer=$Δt_outer (yr) n_inner=$n_inner"

PB.check_modeldata(run.model, modeldata)

solver_view_outer = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_outer)
@info "solver_view_outer: $(solver_view_outer)"
solver_view_inner = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_inner)
Expand Down Expand Up @@ -159,11 +162,12 @@ function integrateEulerthreads(
outputwriter=run.output,
report_interval=1000,
)
PB.check_modeldata(run.model, modeldata)

nt = Threads.nthreads()
nt == 1 || modeldata.threadsafe ||
error("integrateEulerthreads: Threads.nthreads() = $nt but modeldata is not thread safe "*
"(check initialize!(run::Run, ...))")
"(check initialize!(run.model, ...))")

lc = length(cellranges)
lc == nt ||
Expand Down Expand Up @@ -235,7 +239,8 @@ function integrateSplitEulerthreads(
outputwriter=run.output,
report_interval=1000,
)

PB.check_modeldata(run.model, modeldata)

nt = Threads.nthreads()
nt == 1 || modeldata.threadsafe ||
error("integrateEulerthreads: Threads.nthreads() = $nt but modeldata is not thread safe (check initialize!(run::Run, ...))")
Expand Down Expand Up @@ -338,6 +343,7 @@ function create_timestep_Euler_ctxt(
n_substep=1,
verbose=false,
)
PB.check_modeldata(model, modeldata)

num_constraints = PALEOmodel.num_algebraic_constraints(solver_view)
iszero(num_constraints) || error("DAE problem with $num_constraints algebraic constraints")
Expand Down Expand Up @@ -368,6 +374,7 @@ function integrateFixed(
outputwriter=run.output,
report_interval=1000
)
PB.check_modeldata(run.model, modeldata)

nevals = 0

Expand Down Expand Up @@ -447,6 +454,7 @@ function integrateFixedthreads(
outputwriter=run.output,
report_interval=1000
)
PB.check_modeldata(run.model, modeldata)

nevals = 0

Expand Down
8 changes: 6 additions & 2 deletions src/Run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ function Base.show(io::IO, ::MIME"text/plain", run::Run)
end


initialize!(run::Run; kwargs...) = initialize!(run.model; kwargs...)
function initialize!(run::Run; kwargs...)
Base.depwarn("call to deprecated initialize!(run::Run; ...), please update your code to use initialize!(run.model; ...)", :initialize!, force=true)

return initialize!(run.model; kwargs...)
end

"""
initialize!(model::PB.Model; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData)
initialize!(run::Run; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData)
[deprecated] initialize!(run::Run; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData)
Initialize `model` or `run.model` and return:
- an `initial_state` Vector
Expand Down
5 changes: 5 additions & 0 deletions src/SteadyState.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ function steadystate(
use_norm::Bool=false,
BLAS_num_threads=1,
)
PB.check_modeldata(run.model, modeldata)

LinearAlgebra.BLAS.set_num_threads(BLAS_num_threads)
@info "steadystate: using BLAS with $(LinearAlgebra.BLAS.get_num_threads()) threads"
Expand Down Expand Up @@ -189,6 +190,8 @@ function steadystate_ptc(
verbose=false,
BLAS_num_threads=1
)
PB.check_modeldata(run.model, modeldata)

!use_norm || ArgumentError("use_norm=true not supported")

nlsolveF = nlsolveF_PTC(
Expand Down Expand Up @@ -272,6 +275,8 @@ function nlsolveF_PTC(
request_adchunksize=10,
jac_cellranges=modeldata.cellranges_all,
)
PB.check_modeldata(model, modeldata)

sv = modeldata.solver_view_all

# We only support explicit ODE-like configurations (no DAE constraints or implicit variables)
Expand Down
2 changes: 2 additions & 0 deletions src/SteadyStateKinsol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ function steadystate_ptc(
verbose=false,
BLAS_num_threads=1
)
PB.check_modeldata(run.model, modeldata)

# start, end times
tss, tss_max = tspan

Expand Down

2 comments on commit c9f3936

@sjdaines
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/67060

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.15.8 -m "<description of version>" c9f3936ea5d0e74fd3c74d572a4625cc45d99e54
git push origin v0.15.8

Please sign in to comment.