Skip to content

Commit

Permalink
Additional initialization path
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 16, 2023
1 parent da5aaff commit eccc06f
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/function_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunctio
p::P
end

function TimeGradientWrapper{iip}(f::F, uprev, p) where {F}
return TimeGradientWrapper{iip, F, typeof(uprev), typeof(p)}(f, uprev, p)
end
function TimeGradientWrapper(f::F, uprev, p) where {F}
return TimeGradientWrapper{isinplace(f, 4), F, typeof(uprev), typeof(p)}(f, uprev, p)
return TimeGradientWrapper{isinplace(f, 4)}(f, uprev, p)
end

(ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
Expand All @@ -19,9 +22,10 @@ mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{i
p::P
end

function UJacobianWrapper(f::F, t, p) where {F}
return UJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p)
function UJacobianWrapper{iip}(f::F, t, p) where {F}
return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
end
UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p)

(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
(ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
Expand All @@ -37,8 +41,11 @@ mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{
p::P
end

function TimeDerivativeWrapper{iip}(f::F, u, p) where {F}
return TimeDerivativeWrapper{iip, F, typeof(u), typeof(p)}(f, u, p)
end
function TimeDerivativeWrapper(f::F, u, p) where {F}
return TimeDerivativeWrapper{isinplace(f, 4), F, typeof(u), typeof(p)}(f, u, p)
return TimeDerivativeWrapper{isinplace(f, 4)}(f, u, p)
end

(ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t)
Expand All @@ -51,9 +58,10 @@ mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip
p::P
end

function UDerivativeWrapper(f::F, t, p) where {F}
return UDerivativeWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p)
function UDerivativeWrapper{iip}(f::F, t, p) where {F}
return UDerivativeWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
end
UDerivativeWrapper(f::F, t, p) where {F} = UDerivativeWrapper{isinplace(f, 4)}(f, t, p)

(ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t)
(ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t)
Expand All @@ -65,9 +73,10 @@ mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFu
u::uType
end

function ParamJacobianWrapper(f::F, t, u) where {F}
return ParamJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(u)}(f, t, u)
function ParamJacobianWrapper{iip}(f::F, t, u) where {F}
return ParamJacobianWrapper{iip, F, typeof(t), typeof(u)}(f, t, u)
end
ParamJacobianWrapper(f::F, t, u) where {F} = ParamJacobianWrapper{isinplace(f, 4)}(f, t, u)

(ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t)
function (ff::ParamJacobianWrapper{true})(p)
Expand All @@ -82,9 +91,9 @@ mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip}
p::pType
end

function JacobianWrapper(f::F, p) where {F}
return JacobianWrapper{isinplace(f, 4), F, typeof(p)}(f, p)
end
JacobianWrapper{iip}(f::F, p) where {F} = JacobianWrapper{iip, F, typeof(p)}(f, p)
JacobianWrapper(f::F, p) where {F} = JacobianWrapper{isinplace(f, 3)}(f, p)

(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

0 comments on commit eccc06f

Please sign in to comment.