From da5aaff6009f9fdc6021878977e193c75eb008f6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Nov 2023 16:07:45 -0500 Subject: [PATCH] Propagate IIP information in the Wrapper Functions --- Project.toml | 2 +- src/function_wrappers.jl | 72 +++++++++++++++++++++++++++++----------- 2 files changed, 54 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index a60a6bc76..543a02a8b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.8.1" +version = "2.8.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/function_wrappers.jl b/src/function_wrappers.jl index e97b37b18..967114ffb 100644 --- a/src/function_wrappers.jl +++ b/src/function_wrappers.jl @@ -1,56 +1,90 @@ -mutable struct TimeGradientWrapper{fType, uType, P} <: Function +mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunction{iip} f::fType uprev::uType p::P end -(ff::TimeGradientWrapper)(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2) -(ff::TimeGradientWrapper)(du2, t) = ff.f(du2, ff.uprev, ff.p, t) -mutable struct UJacobianWrapper{fType, tType, P} <: Function +function TimeGradientWrapper(f::F, uprev, p) where {F} + return TimeGradientWrapper{isinplace(f, 4), F, typeof(uprev), typeof(p)}(f, uprev, p) +end + +(ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2) +(ff::TimeGradientWrapper{true})(du2, t) = ff.f(du2, ff.uprev, ff.p, t) + +(ff::TimeGradientWrapper{false})(t) = ff.f(ff.uprev, ff.p, t) + +mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{iip} f::fType t::tType p::P end -(ff::UJacobianWrapper)(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t) -(ff::UJacobianWrapper)(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1) -(ff::UJacobianWrapper)(du1, uprev, p, t) = ff.f(du1, uprev, p, t) -(ff::UJacobianWrapper)(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1) +function UJacobianWrapper(f::F, t, p) where {F} + return UJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p) +end + +(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t) +(ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1) +(ff::UJacobianWrapper{true})(du1, uprev, p, t) = ff.f(du1, uprev, p, t) +(ff::UJacobianWrapper{true})(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1) -mutable struct TimeDerivativeWrapper{F, uType, P} <: Function +(ff::UJacobianWrapper{false})(uprev) = ff.f(uprev, ff.p, ff.t) +(ff::UJacobianWrapper{false})(uprev, p, t) = ff.f(uprev, p, t) + +mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{iip} f::F u::uType p::P end -(ff::TimeDerivativeWrapper)(t) = ff.f(ff.u, ff.p, t) -mutable struct UDerivativeWrapper{F, tType, P} <: Function +function TimeDerivativeWrapper(f::F, u, p) where {F} + return TimeDerivativeWrapper{isinplace(f, 4), F, typeof(u), typeof(p)}(f, u, p) +end + +(ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t) +(ff::TimeDerivativeWrapper{true})(du1, t) = ff.f(du1, ff.u, ff.p, t) +(ff::TimeDerivativeWrapper{true})(t) = (du1 = similar(ff.u); ff.f(du1, ff.u, ff.p, t); du1) + +mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip} f::F t::tType p::P end -(ff::UDerivativeWrapper)(u) = ff.f(u, ff.p, ff.t) -mutable struct ParamJacobianWrapper{fType, tType, uType} <: Function +function UDerivativeWrapper(f::F, t, p) where {F} + return UDerivativeWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p) +end + +(ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t) +(ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t) +(ff::UDerivativeWrapper{true})(u) = (du1 = similar(u); ff.f(du1, u, ff.p, ff.t); du1) + +mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFunction{iip} f::fType t::tType u::uType end -function (ff::ParamJacobianWrapper)(du1, p) - ff.f(du1, ff.u, p, ff.t) +function ParamJacobianWrapper(f::F, t, u) where {F} + return ParamJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(u)}(f, t, u) end -function (ff::ParamJacobianWrapper)(p) +(ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t) +function (ff::ParamJacobianWrapper{true})(p) du1 = similar(p, size(ff.u)) ff.f(du1, ff.u, p, ff.t) return du1 end +(ff::ParamJacobianWrapper{false})(p) = ff.f(ff.u, p, ff.t) -mutable struct JacobianWrapper{fType, pType} +mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip} f::fType p::pType end -(uf::JacobianWrapper)(u) = uf.f(u, uf.p) -(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p) +function JacobianWrapper(f::F, p) where {F} + return JacobianWrapper{isinplace(f, 4), F, typeof(p)}(f, p) +end + +(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p) +(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)