From 26d5311ecfc1f0c5ce2a5a0fe2a13cf048c41e7b Mon Sep 17 00:00:00 2001 From: Daines Date: Mon, 29 Aug 2022 10:49:17 +0100 Subject: [PATCH] Test speculative fix/workaround for kinsol segfault See https://github.com/PALEOtoolkit/PALEOmodel.jl/issues/25 This rearranges the kinsol wrapper to a simpler form --- src/Kinsol.jl | 53 +++++++++++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/src/Kinsol.jl b/src/Kinsol.jl index cee1666..69cd834 100644 --- a/src/Kinsol.jl +++ b/src/Kinsol.jl @@ -120,31 +120,31 @@ function kin_create( psolvefun = nothing, jvfun = nothing, ) + # use the user_data field to pass a function + # see: https://github.com/JuliaLang/julia/issues/2554 + userfun = UserFunctionAndData(f, psetupfun, psolvefun, jvfun, userdata) + + return _kin_create(userfun, y0; linear_solver=linear_solver, jac_upper=jac_upper, jac_lower=jac_lower, krylov_dim=krylov_dim) +end + +function _kin_create( + userfun::T, y0::Vector{Float64}; + linear_solver, + jac_upper, + jac_lower, + krylov_dim, +) where {T} mem_ptr = Sundials.KINCreate() (mem_ptr == C_NULL) && error("Failed to allocate KINSOL solver object") kmem = Sundials.Handle(mem_ptr) handles = [] + + push!(handles, userfun) # TODO prevent userfun from being garbage collected ? - # use the user_data field to pass a function - # see: https://github.com/JuliaLang/julia/issues/2554 - userfun = UserFunctionAndData(f, psetupfun, psolvefun, jvfun, userdata) - # push!(handles, userfun) # TODO prevent userfun from being garbage collected ? - function getkinsolfun(userfun::T) where {T} - @cfunction(kinsolfun, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ref{T})) - end - function getpsetupfun(userfun::T) where {T} - @cfunction(kinprecsetup, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{T})) - end - function getpsolvefun(userfun::T) where {T} - @cfunction(kinprecsolve, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{T})) - end - function getkinjactimesvec(userfun::T) where {T} - @cfunction(kinjactimesvec, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cint}, Ref{T})) - end - - flag = Sundials.@checkflag Sundials.KINInit(kmem, getkinsolfun(userfun), Sundials.NVector(y0)) true + c_kinsolfun = @cfunction(kinsolfun, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ref{T})) + flag = Sundials.@checkflag Sundials.KINInit(kmem, c_kinsolfun, Sundials.NVector(y0)) true if linear_solver == :Dense A = Sundials.SUNDenseMatrix(length(y0), length(y0)) @@ -158,7 +158,7 @@ function kin_create( push!(handles, Sundials.LinSolHandle(LS, Sundials.Band())) elseif linear_solver == :FGMRES A = nothing - prec_side = isnothing(psolvefun) ? 0 : 2 # right preconditioning only + prec_side = isnothing(userfun.psolve) ? 0 : 2 # right preconditioning only LS = Sundials.SUNLinSol_SPFGMR(y0, prec_side, krylov_dim) push!(handles, Sundials.LinSolHandle(LS, Sundials.SPFGMR())) end @@ -168,17 +168,20 @@ function kin_create( flag = Sundials.@checkflag Sundials.KINSetLinearSolver(kmem, LS, A === nothing ? C_NULL : A) true # flag = Sundials.@checkflag Sundials.KINDlsSetLinearSolver(kmem, LS, A === nothing ? C_NULL : A) true - if !isnothing(psolvefun) + if !isnothing(userfun.psolve) + c_kinprecsetup = isnothing(userfun.psetup) ? C_NULL : @cfunction(kinprecsetup, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{T})) + c_kinprecsolve = @cfunction(kinprecsolve, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{T})) + flag = Sundials.@checkflag Sundials.KINSetPreconditioner(kmem, - psetupfun === nothing ? C_NULL : getpsetupfun(userfun), - getpsolvefun(userfun)) true + c_kinprecsetup, + c_kinprecsolve) true end - if !isnothing(jvfun) - flag = Sundials.@checkflag Sundials.KINSetJacTimesVecFn(kmem, getkinjactimesvec(userfun)) true + if !isnothing(userfun.jv) + c_kinjactimesvec = @cfunction(kinjactimesvec, Cint, (Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cint}, Ref{T})) + flag = Sundials.@checkflag Sundials.KINSetJacTimesVecFn(kmem, c_kinjactimesvec) true end - return (;kmem, handles) end