diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index faaafe9db..ecb49592e 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2661,6 +2661,35 @@ function NonlinearFunction{iip}(f::ODEFunction) where {iip} colorvec = f.colorvec) end +function unwrapped_f(f::NonlinearFunction, newf = unwrapped_f(f.f)) + if specialization(f) === NoSpecialize + return NonlinearFunction{isinplace(f), specialization(f), Any, Any, + Any, Any, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, + typeof(f.colorvec), Any, Any}(newf, f.mass_matrix, + f.analytic, f.tgrad, f.jac, + f.jvp, f.vjp, f.jac_prototype, + f.sparsity, f.Wfact, + f.Wfact_t, f.paramjac, + f.observed, f.colorvec, f.sys, + f.resid_prototype) + else + return NonlinearFunction{isinplace(f), specialization(f), typeof(newf), + typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), + typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype), + typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), + typeof(f.paramjac), + typeof(f.observed), typeof(f.colorvec), + typeof(f.sys), typeof(f.resid_prototype)}(newf, f.mass_matrix, + f.analytic, f.tgrad, f.jac, + f.jvp, f.vjp, f.jac_prototype, + f.sparsity, f.Wfact, + f.Wfact_t, f.paramjac, + f.observed, f.colorvec, f.sys, + f.resid_prototype) + end +end + @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob, update_initializeprob!,