From baaad84ba06a1616fd63cd2a0544cfa989b7274a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Nov 2023 19:55:37 -0500 Subject: [PATCH] Additional initialization path --- src/function_wrappers.jl | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/function_wrappers.jl b/src/function_wrappers.jl index 967114ffb..93f5394c4 100644 --- a/src/function_wrappers.jl +++ b/src/function_wrappers.jl @@ -4,8 +4,11 @@ mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunctio p::P end +function TimeGradientWrapper{iip}(f::F, uprev, p) where {F, iip} + return TimeGradientWrapper{iip, F, typeof(uprev), typeof(p)}(f, uprev, p) +end function TimeGradientWrapper(f::F, uprev, p) where {F} - return TimeGradientWrapper{isinplace(f, 4), F, typeof(uprev), typeof(p)}(f, uprev, p) + return TimeGradientWrapper{isinplace(f, 4)}(f, uprev, p) end (ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2) @@ -19,9 +22,10 @@ mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{i p::P end -function UJacobianWrapper(f::F, t, p) where {F} - return UJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p) +function UJacobianWrapper{iip}(f::F, t, p) where {F, iip} + return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p) end +UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p) (ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t) (ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1) @@ -37,8 +41,11 @@ mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{ p::P end +function TimeDerivativeWrapper{iip}(f::F, u, p) where {F, iip} + return TimeDerivativeWrapper{iip, F, typeof(u), typeof(p)}(f, u, p) +end function TimeDerivativeWrapper(f::F, u, p) where {F} - return TimeDerivativeWrapper{isinplace(f, 4), F, typeof(u), typeof(p)}(f, u, p) + return TimeDerivativeWrapper{isinplace(f, 4)}(f, u, p) end (ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t) @@ -51,9 +58,10 @@ mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip p::P end -function UDerivativeWrapper(f::F, t, p) where {F} - return UDerivativeWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p) +function UDerivativeWrapper{iip}(f::F, t, p) where {F, iip} + return UDerivativeWrapper{iip, F, typeof(t), typeof(p)}(f, t, p) end +UDerivativeWrapper(f::F, t, p) where {F} = UDerivativeWrapper{isinplace(f, 4)}(f, t, p) (ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t) (ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t) @@ -65,9 +73,10 @@ mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFu u::uType end -function ParamJacobianWrapper(f::F, t, u) where {F} - return ParamJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(u)}(f, t, u) +function ParamJacobianWrapper{iip}(f::F, t, u) where {F, iip} + return ParamJacobianWrapper{iip, F, typeof(t), typeof(u)}(f, t, u) end +ParamJacobianWrapper(f::F, t, u) where {F} = ParamJacobianWrapper{isinplace(f, 4)}(f, t, u) (ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t) function (ff::ParamJacobianWrapper{true})(p) @@ -82,9 +91,9 @@ mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip} p::pType end -function JacobianWrapper(f::F, p) where {F} - return JacobianWrapper{isinplace(f, 4), F, typeof(p)}(f, p) -end +JacobianWrapper{iip}(f::F, p) where {F, iip} = JacobianWrapper{iip, F, typeof(p)}(f, p) +JacobianWrapper(f::F, p) where {F} = JacobianWrapper{isinplace(f, 3)}(f, p) (uf::JacobianWrapper{false})(u) = uf.f(u, uf.p) +(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p))) (uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)