From f4d8e368d248ca0bd4d0e2f73f71bebb92894bed Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 25 Sep 2024 16:44:15 -0400 Subject: [PATCH] add nlfunc to ODEFunction --- src/scimlfunctions.jl | 48 ++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 586ebebb7..98dec2ed0 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -289,6 +289,7 @@ the usage of `f`. These include: based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be internally computed on demand when required. The cost of this operation is highly dependent on the sparsity pattern. +- `nlfunc`: a `NonlinearFunction` ## iip: In-Place vs Out-Of-Place @@ -401,8 +402,8 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, - O, TCV, - SYS, IProb, IProbMap} <: AbstractODEFunction{iip} + O, TCV, SYS, IProb, IProbMap, + NLF} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -421,6 +422,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW sys::SYS initializeprob::IProb initializeprobmap::IProbMap + nlfunc::NLF end @doc doc""" @@ -517,8 +519,8 @@ information on generating the SplitFunction from this symbolic engine. """ struct SplitFunction{ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, - TPJ, O, - TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip} + TPJ, O, TCV, SYS, IProb, IProbMap, + NLF} <: AbstractODEFunction{iip} f1::F1 f2::F2 mass_matrix::TMM @@ -538,6 +540,7 @@ struct SplitFunction{ sys::SYS initializeprob::IProb initializeprobmap::IProbMap + nlfunc::NLF end @doc doc""" @@ -2415,7 +2418,8 @@ function ODEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, + nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing, ) where {iip, specialize } @@ -2471,12 +2475,13 @@ function ODEFunction{iip, specialize}(f; Any, Any, Any, Any, Any, Any, Any, typeof(jac_prototype), typeof(sparsity), Any, Any, typeof(W_prototype), Any, - Any, - typeof(_colorvec), - typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, + Any,typeof(_colorvec), + typeof(sys), Any, Any, + Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + nlfunc) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2486,10 +2491,12 @@ function ODEFunction{iip, specialize}(f; typeof(observed), typeof(_colorvec), typeof(sys), typeof(initializeprob), - typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(initializeprobmap), + typeof(nlfunc)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + nlfun) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2499,10 +2506,12 @@ function ODEFunction{iip, specialize}(f; typeof(observed), typeof(_colorvec), typeof(sys), typeof(initializeprob), - typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(initializeprobmap), + typeof(nlfunc))}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + nlfunc) end end @@ -2519,10 +2528,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), Any, Any, + Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap, + f.nlfunc) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2531,11 +2542,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), typeof(f.sys), typeof(f.initializeprob), - typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.initializeprobmap), + typeof(f.nlfunc)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, f.observed, f.colorvec, f.sys, f.initializeprob, - f.initializeprobmap) + f.initializeprobmap, + f.nlfunc) end end @@ -4336,6 +4349,7 @@ __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) __has_initializeprob(f) = isdefined(f, :initializeprob) __has_initializeprobmap(f) = isdefined(f, :initializeprobmap) +__has_nlfunc(f) = isdefined(f, :nl_func) # compatibility has_invW(f::AbstractSciMLFunction) = false