Skip to content

Commit

Permalink
Merge pull request #543 from SciML/ap/propagate_iip
Browse files Browse the repository at this point in the history
Propagate IIP information in the Wrapper Functions
  • Loading branch information
ChrisRackauckas authored Nov 17, 2023
2 parents ece1966 + d7f5284 commit f6b059c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 21 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.8.1"
version = "2.8.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -69,6 +69,8 @@ PartialFunctions = "1.1"
PrecompileTools = "1"
Preferences = "1.3"
Printf = "1.9"
PyCall = "1.96"
PythonCall = "0.9"
RCall = "0.13.18"
RecipesBase = "0.7.0, 0.8, 1.0"
RecursiveArrayTools = "2.33"
Expand Down
81 changes: 62 additions & 19 deletions src/function_wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,99 @@
mutable struct TimeGradientWrapper{fType, uType, P} <: Function
mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunction{iip}
f::fType
uprev::uType
p::P
end
(ff::TimeGradientWrapper)(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
(ff::TimeGradientWrapper)(du2, t) = ff.f(du2, ff.uprev, ff.p, t)

mutable struct UJacobianWrapper{fType, tType, P} <: Function
function TimeGradientWrapper{iip}(f::F, uprev, p) where {F, iip}
return TimeGradientWrapper{iip, F, typeof(uprev), typeof(p)}(f, uprev, p)
end
function TimeGradientWrapper(f::F, uprev, p) where {F}
return TimeGradientWrapper{isinplace(f, 4)}(f, uprev, p)
end

(ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
(ff::TimeGradientWrapper{true})(du2, t) = ff.f(du2, ff.uprev, ff.p, t)

(ff::TimeGradientWrapper{false})(t) = ff.f(ff.uprev, ff.p, t)

mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{iip}
f::fType
t::tType
p::P
end

(ff::UJacobianWrapper)(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
(ff::UJacobianWrapper)(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
(ff::UJacobianWrapper)(du1, uprev, p, t) = ff.f(du1, uprev, p, t)
(ff::UJacobianWrapper)(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1)
function UJacobianWrapper{iip}(f::F, t, p) where {F, iip}
return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
end
UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p)

(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
(ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
(ff::UJacobianWrapper{true})(du1, uprev, p, t) = ff.f(du1, uprev, p, t)
(ff::UJacobianWrapper{true})(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1)

mutable struct TimeDerivativeWrapper{F, uType, P} <: Function
(ff::UJacobianWrapper{false})(uprev) = ff.f(uprev, ff.p, ff.t)
(ff::UJacobianWrapper{false})(uprev, p, t) = ff.f(uprev, p, t)

mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{iip}
f::F
u::uType
p::P
end
(ff::TimeDerivativeWrapper)(t) = ff.f(ff.u, ff.p, t)

mutable struct UDerivativeWrapper{F, tType, P} <: Function
function TimeDerivativeWrapper{iip}(f::F, u, p) where {F, iip}
return TimeDerivativeWrapper{iip, F, typeof(u), typeof(p)}(f, u, p)
end
function TimeDerivativeWrapper(f::F, u, p) where {F}
return TimeDerivativeWrapper{isinplace(f, 4)}(f, u, p)
end

(ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t)
(ff::TimeDerivativeWrapper{true})(du1, t) = ff.f(du1, ff.u, ff.p, t)
(ff::TimeDerivativeWrapper{true})(t) = (du1 = similar(ff.u); ff.f(du1, ff.u, ff.p, t); du1)

mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip}
f::F
t::tType
p::P
end
(ff::UDerivativeWrapper)(u) = ff.f(u, ff.p, ff.t)

mutable struct ParamJacobianWrapper{fType, tType, uType} <: Function
function UDerivativeWrapper{iip}(f::F, t, p) where {F, iip}
return UDerivativeWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
end
UDerivativeWrapper(f::F, t, p) where {F} = UDerivativeWrapper{isinplace(f, 4)}(f, t, p)

(ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t)
(ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t)
(ff::UDerivativeWrapper{true})(u) = (du1 = similar(u); ff.f(du1, u, ff.p, ff.t); du1)

mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFunction{iip}
f::fType
t::tType
u::uType
end

function (ff::ParamJacobianWrapper)(du1, p)
ff.f(du1, ff.u, p, ff.t)
function ParamJacobianWrapper{iip}(f::F, t, u) where {F, iip}
return ParamJacobianWrapper{iip, F, typeof(t), typeof(u)}(f, t, u)
end
ParamJacobianWrapper(f::F, t, u) where {F} = ParamJacobianWrapper{isinplace(f, 4)}(f, t, u)

function (ff::ParamJacobianWrapper)(p)
(ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t)
function (ff::ParamJacobianWrapper{true})(p)
du1 = similar(p, size(ff.u))
ff.f(du1, ff.u, p, ff.t)
return du1
end
(ff::ParamJacobianWrapper{false})(p) = ff.f(ff.u, p, ff.t)

mutable struct JacobianWrapper{fType, pType}
mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip}
f::fType
p::pType
end

(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
JacobianWrapper{iip}(f::F, p) where {F, iip} = JacobianWrapper{iip, F, typeof(p)}(f, p)
JacobianWrapper(f::F, p) where {F} = JacobianWrapper{isinplace(f, 3)}(f, p)

(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)
2 changes: 1 addition & 1 deletion test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
@testset "Aqua tests (additional)" begin
Aqua.test_undefined_exports(SciMLBase)
Aqua.test_stale_deps(SciMLBase)
Aqua.test_deps_compat(SciMLBase)
Aqua.test_deps_compat(SciMLBase, check_extras = false)
Aqua.test_project_extras(SciMLBase)
# Aqua.test_project_toml_formatting(SciMLBase) # failing
# Aqua.test_piracy(SciMLBase) # failing
Expand Down

0 comments on commit f6b059c

Please sign in to comment.